Merge pull request #2518 from vrv/branch_123332988

Upstream changes from internal

Vijay Vasudevan committed on 2016-05-26 15:08:06 -07:00 (commit 15e51e6113)
602 changed files with 943 additions and 321 deletions


@ -53,6 +53,7 @@ def optimize_loss(loss,
clip_gradients=None, clip_gradients=None,
moving_average_decay=0.9, moving_average_decay=0.9,
learning_rate_decay_fn=None, learning_rate_decay_fn=None,
update_ops=None,
variables=None, variables=None,
name=None): name=None):
"""Given loss and parameters for optimizer, returns a training op. """Given loss and parameters for optimizer, returns a training op.
@ -81,6 +82,8 @@ def optimize_loss(loss,
Can be used to implement any learning rate decay Can be used to implement any learning rate decay
functions. functions.
For example: tf.train.exponential_decay. For example: tf.train.exponential_decay.
update_ops: list of update `Operation`s to execute at each step. If `None`,
uses elements of UPDATE_OPS collection.
variables: list of variables to optimize or variables: list of variables to optimize or
`None` to use all trainable variables. `None` to use all trainable variables.
name: The name for this operation is used to scope operations and summaries. name: The name for this operation is used to scope operations and summaries.
@ -92,6 +95,15 @@ def optimize_loss(loss,
ValueError: if optimizer is wrong type. ValueError: if optimizer is wrong type.
""" """
with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"): with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"):
# Update ops take UPDATE_OPS collection if not provided.
update_ops = (set(update_ops or []) or
set(ops.get_collection(ops.GraphKeys.UPDATE_OPS)))
# Make sure update ops are run before computing loss.
if update_ops:
with ops.control_dependencies(update_ops):
barrier = control_flow_ops.no_op(name="update_barrier")
loss = control_flow_ops.with_dependencies([barrier], loss)
# Moving average of the loss with decay. # Moving average of the loss with decay.
if moving_average_decay is not None: if moving_average_decay is not None:
# Generate moving averages of the loss. # Generate moving averages of the loss.
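
For readers skimming the diff, here is a minimal sketch of the gating pattern the new `update_ops` handling introduces: update ops are forced to run before the loss is evaluated via a control dependency. The variable, loss, and update op below are illustrative placeholders, not part of this change, and only the public `tf` API of this release line is used.

import tensorflow as tf

# Illustrative model state; optimize_loss itself receives these from the caller.
var = tf.Variable(1.0)
loss = 3.0 * var
update_op = tf.assign_add(var, 0.5)  # e.g. a batch-norm moving-average update

# Same idea as the added code: evaluate the loss only after update_op has run.
with tf.control_dependencies([update_op]):
    gated_loss = tf.identity(loss)

train_op = tf.train.GradientDescentOptimizer(0.1).minimize(gated_loss)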


@ -132,6 +132,25 @@ class OptimizersTest(tf.test.TestCase):
tf.contrib.layers.optimize_loss( tf.contrib.layers.optimize_loss(
loss, global_step, learning_rate=0.1, optimizer="SGD") loss, global_step, learning_rate=0.1, optimizer="SGD")
def testUpdateOp(self):
optimizers = ["SGD", tf.train.GradientDescentOptimizer,
tf.train.GradientDescentOptimizer(learning_rate=0.1)]
for optimizer in optimizers:
with tf.Graph().as_default() as g:
with self.test_session(graph=g) as session:
x, var, loss, global_step = _setup_model()
update_op = tf.assign(var, 20)
train = tf.contrib.layers.optimize_loss(loss,
global_step,
learning_rate=0.1,
optimizer=optimizer,
update_ops=[update_op])
tf.initialize_all_variables().run()
session.run(train, feed_dict={x: 5})
var_value, global_step_value = session.run([var, global_step])
# 19.5, due to update of var to 20 before loss computation.
self.assertEqual(var_value, 19.5)
self.assertEqual(global_step_value, 1)
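
The expected 19.5 follows from one SGD step taken after the update op has run; assuming `_setup_model` (not shown in this diff) builds a loss whose gradient with respect to `var` equals the fed `x`, the arithmetic is:

var_after_update = 20.0          # update_ops=[tf.assign(var, 20)] runs first
learning_rate, x = 0.1, 5.0      # values used by the test
var_after_step = var_after_update - learning_rate * x
assert var_after_step == 19.5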
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()


@ -195,7 +195,10 @@ def train(graph,
raise ValueError('No "global_step" was provided or found in the graph.') raise ValueError('No "global_step" was provided or found in the graph.')
# TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors. # TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors.
if not monitors: if not supervisor_is_chief:
# Monitors should run only on the chief supervisor.
monitors = []
elif not monitors:
monitors = monitors_lib.get_default_monitors( monitors = monitors_lib.get_default_monitors(
loss_op=loss_op, loss_op=loss_op,
summary_op=logging_ops.get_summary_op(), summary_op=logging_ops.get_summary_op(),


@ -26,8 +26,9 @@ from tensorflow.python.training import input as input_ops
def read_batch_examples(file_pattern, batch_size, reader, def read_batch_examples(file_pattern, batch_size, reader,
randomize_input=True, queue_capacity=10000, randomize_input=True, num_epochs=None,
num_threads=1, name='dequeue_examples'): queue_capacity=10000, num_threads=1,
name=None):
"""Adds operations to read, queue, batch `Example` protos. """Adds operations to read, queue, batch `Example` protos.
Given file pattern (or list of files), will setup a queue for file names, Given file pattern (or list of files), will setup a queue for file names,
@ -46,6 +47,10 @@ def read_batch_examples(file_pattern, batch_size, reader,
reader: A function or class that returns an object with reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor). `read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized. randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If `None`, cycles through the dataset forever.
NOTE - If specified, creates a variable that must be initialized, so call
`tf.initialize_all_variables()` as shown in the tests.
queue_capacity: Capacity for input queue. queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples. num_threads: The number of threads enqueuing examples.
name: Name of resulting op. name: Name of resulting op.
@ -82,39 +87,47 @@ def read_batch_examples(file_pattern, batch_size, reader,
(batch_size, queue_capacity)) (batch_size, queue_capacity))
if (not num_threads) or (num_threads <= 0): if (not num_threads) or (num_threads <= 0):
raise ValueError('Invalid num_threads %s.' % num_threads) raise ValueError('Invalid num_threads %s.' % num_threads)
if (num_epochs is not None) and (num_epochs <= 0):
raise ValueError('Invalid num_epochs %s.' % num_epochs)
with ops.name_scope(name) as scope: with ops.op_scope([file_pattern], name, 'read_batch_examples') as scope:
# Setup filename queue with shuffling. # Setup filename queue with shuffling.
with ops.name_scope('file_name_queue') as file_name_queue_scope: with ops.name_scope('file_name_queue') as file_name_queue_scope:
file_name_queue = input_ops.string_input_producer( file_name_queue = input_ops.string_input_producer(
constant_op.constant(file_names, name='input'), constant_op.constant(file_names, name='input'),
shuffle=randomize_input, name=file_name_queue_scope) shuffle=randomize_input, num_epochs=num_epochs,
name=file_name_queue_scope)
# Create reader and set it to read from filename queue. # Create readers, one per thread and set them to read from filename queue.
with ops.name_scope('read'): with ops.name_scope('read'):
_, example_proto = reader().read(file_name_queue) example_list = []
for _ in range(num_threads):
_, example_proto = reader().read(file_name_queue)
example_list.append([example_proto])
# Setup batching queue. # Setup batching queue given list of read example tensors.
if randomize_input: if randomize_input:
if isinstance(batch_size, ops.Tensor): if isinstance(batch_size, ops.Tensor):
min_after_dequeue = int(queue_capacity * 0.4) min_after_dequeue = int(queue_capacity * 0.4)
else: else:
min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size) min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
examples = input_ops.shuffle_batch( examples = input_ops.shuffle_batch_join(
[example_proto], batch_size, capacity=queue_capacity, example_list, batch_size, capacity=queue_capacity,
num_threads=num_threads, min_after_dequeue=min_after_dequeue, min_after_dequeue=min_after_dequeue,
name=scope) name=scope)
else: else:
examples = input_ops.batch( examples = input_ops.batch_join(
[example_proto], batch_size, capacity=queue_capacity, example_list, batch_size, capacity=queue_capacity,
num_threads=num_threads, name=scope) name=scope)
return examples return examples
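
A usage sketch for the new `num_epochs` argument, mirroring the `test_read_csv` test later in this diff; the file pattern is a placeholder, and the epoch-limit variable created by `num_epochs` must be initialized before the queue runners start:

import tensorflow as tf

examples = tf.contrib.learn.io.read_batch_examples(
    "data/*.csv", batch_size=32, reader=tf.TextLineReader,
    randomize_input=False, num_epochs=1, queue_capacity=500)

with tf.Session() as session:
    session.run(tf.initialize_all_variables())   # initializes the epoch counter
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(session, coord=coord)
    try:
        while True:
            batch = session.run(examples)        # one batch of raw lines
    except tf.errors.OutOfRangeError:
        pass                                     # one full epoch has been read
    finally:
        coord.request_stop()
        coord.join(threads)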
def read_batch_features(file_pattern, batch_size, features, reader, def read_batch_features(file_pattern, batch_size, features, reader,
randomize_input=True, queue_capacity=10000, randomize_input=True, num_epochs=None,
num_threads=1, name='dequeue_examples'): queue_capacity=10000, reader_num_threads=1,
parser_num_threads=1,
name=None):
"""Adds operations to read, queue, batch and parse `Example` protos. """Adds operations to read, queue, batch and parse `Example` protos.
Given file pattern (or list of files), will setup a queue for file names, Given file pattern (or list of files), will setup a queue for file names,
@ -136,8 +149,13 @@ def read_batch_features(file_pattern, batch_size, features, reader,
reader: A function or class that returns an object with reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor). `read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized. randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.initialize_all_variables() as shown in the tests.
queue_capacity: Capacity for input queue. queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples. reader_num_threads: The number of threads to read examples.
parser_num_threads: The number of threads to parse examples.
name: Name of resulting op. name: Name of resulting op.
Returns: Returns:
@ -146,17 +164,29 @@ def read_batch_features(file_pattern, batch_size, features, reader,
Raises: Raises:
ValueError: for invalid inputs. ValueError: for invalid inputs.
""" """
examples = read_batch_examples( with ops.op_scope([file_pattern], name, 'read_batch_features') as scope:
file_pattern, batch_size, reader, randomize_input, examples = read_batch_examples(
queue_capacity, num_threads, name=name) file_pattern, batch_size, reader, randomize_input=randomize_input,
num_epochs=num_epochs, queue_capacity=queue_capacity,
num_threads=reader_num_threads, name=scope)
# Parse features into tensors. # Parse features into tensors in many threads and put on the queue.
return parsing_ops.parse_example(examples, features) features_list = []
for _ in range(parser_num_threads):
features_list.append(parsing_ops.parse_example(examples, features))
return input_ops.batch_join(
features_list,
batch_size=batch_size,
capacity=queue_capacity,
enqueue_many=True,
name='parse_example_batch_join')
def read_batch_record_features(file_pattern, batch_size, features, def read_batch_record_features(file_pattern, batch_size, features,
randomize_input=True, queue_capacity=10000, randomize_input=True, num_epochs=None,
num_threads=1, name='dequeue_record_examples'): queue_capacity=10000, reader_num_threads=1,
parser_num_threads=1,
name='dequeue_record_examples'):
"""Reads TFRecord, queues, batches and parses `Example` proto. """Reads TFRecord, queues, batches and parses `Example` proto.
See more detailed description in `read_examples`. See more detailed description in `read_examples`.
@ -168,8 +198,13 @@ def read_batch_record_features(file_pattern, batch_size, features,
features: A `dict` mapping feature keys to `FixedLenFeature` or features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values. `VarLenFeature` values.
randomize_input: Whether the input should be randomized. randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.initialize_all_variables() as shown in the tests.
queue_capacity: Capacity for input queue. queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples. reader_num_threads: The number of threads to read examples.
parser_num_threads: The number of threads to parse examples.
name: Name of resulting op. name: Name of resulting op.
Returns: Returns:
@ -181,5 +216,6 @@ def read_batch_record_features(file_pattern, batch_size, features,
return read_batch_features( return read_batch_features(
file_pattern=file_pattern, batch_size=batch_size, features=features, file_pattern=file_pattern, batch_size=batch_size, features=features,
reader=io_ops.TFRecordReader, reader=io_ops.TFRecordReader,
randomize_input=randomize_input, randomize_input=randomize_input, num_epochs=num_epochs,
queue_capacity=queue_capacity, num_threads=num_threads, name=name) queue_capacity=queue_capacity, reader_num_threads=reader_num_threads,
parser_num_threads=parser_num_threads, name=name)
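
And a corresponding sketch for the parsed-features entry point with the new threading knobs; the pattern and feature spec are placeholders, and the returned dict is consumed inside a session with queue runners started, exactly as in the sketch above:

import tensorflow as tf

feature_spec = {"age": tf.FixedLenFeature(shape=[1], dtype=tf.int64)}
features = tf.contrib.learn.io.read_batch_record_features(
    "data/*.tfrecord", batch_size=64, features=feature_spec,
    randomize_input=True, num_epochs=None, queue_capacity=10000,
    reader_num_threads=2, parser_num_threads=2)
# features["age"] is a batched tensor; with num_epochs=None the pipeline
# cycles through the files forever, so no epoch-limit variable is created.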


@ -17,10 +17,13 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import random import random
import tempfile
import tensorflow as tf import tensorflow as tf
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
@ -55,44 +58,83 @@ class GraphIOTest(tf.test.TestCase):
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, "No files match", ValueError, "No files match",
tf.contrib.learn.io.read_batch_features, tf.contrib.learn.io.read_batch_examples,
_INVALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader, _INVALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, queue_capacity, False, num_epochs=None, queue_capacity=queue_capacity,
num_threads, name) num_threads=num_threads, name=name)
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, "Invalid batch_size", ValueError, "Invalid batch_size",
tf.contrib.learn.io.read_batch_features, tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, None, None, tf.TFRecordReader, _VALID_FILE_PATTERN, None, tf.TFRecordReader,
False, queue_capacity, num_threads, name) False, num_epochs=None, queue_capacity=queue_capacity,
num_threads=num_threads, name=name)
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, "Invalid batch_size", ValueError, "Invalid batch_size",
tf.contrib.learn.io.read_batch_features, tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, -1, None, tf.TFRecordReader, _VALID_FILE_PATTERN, -1, tf.TFRecordReader,
False, queue_capacity, num_threads, name) False, num_epochs=None, queue_capacity=queue_capacity,
num_threads=num_threads, name=name)
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, "Invalid queue_capacity", ValueError, "Invalid queue_capacity",
tf.contrib.learn.io.read_batch_features, tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader, _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, None, num_threads, name) False, num_epochs=None, queue_capacity=None,
num_threads=num_threads, name=name)
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, "Invalid num_threads", ValueError, "Invalid num_threads",
tf.contrib.learn.io.read_batch_features, tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader, _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, queue_capacity, None, False, num_epochs=None, queue_capacity=queue_capacity,
name) num_threads=None, name=name)
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, "Invalid num_threads", ValueError, "Invalid num_threads",
tf.contrib.learn.io.read_batch_features, tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader, _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, queue_capacity, -1, False, num_epochs=None, queue_capacity=queue_capacity,
name) num_threads=-1, name=name)
self.assertRaisesRegexp( self.assertRaisesRegexp(
ValueError, "Invalid batch_size", ValueError, "Invalid batch_size",
tf.contrib.learn.io.read_batch_features, tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, queue_capacity + 1, None, tf.TFRecordReader, _VALID_FILE_PATTERN, queue_capacity + 1, tf.TFRecordReader,
False, queue_capacity, 1, name) False, num_epochs=None, queue_capacity=queue_capacity,
num_threads=1, name=name)
self.assertRaisesRegexp(
ValueError, "Invalid num_epochs",
tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, num_epochs=-1, queue_capacity=queue_capacity, num_threads=1,
name=name)
def test_batch_tf_record(self): def test_batch_record_features(self):
batch_size = 17
queue_capacity = 1234
name = "my_batch"
features = {"feature": tf.FixedLenFeature(shape=[0], dtype=tf.float32)}
with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
features = tf.contrib.learn.io.read_batch_record_features(
_VALID_FILE_PATTERN, batch_size, features, randomize_input=False,
queue_capacity=queue_capacity, reader_num_threads=2,
parser_num_threads=2, name=name)
self.assertEquals("%s/parse_example_batch_join:0" % name,
features["feature"].name)
file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % file_name_queue_name
example_queue_name = "%s/fifo_queue" % name
parse_example_queue_name = "%s/parse_example_batch_join" % name
op_nodes = test_util.assert_ops_in_graph({
file_names_name: "Const",
file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "FIFOQueue",
parse_example_queue_name: "QueueDequeueMany",
name: "QueueDequeueMany"
}, g)
self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
self.assertEqual(
queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)
def test_one_epoch(self):
batch_size = 17 batch_size = 17
queue_capacity = 1234 queue_capacity = 1234
name = "my_batch" name = "my_batch"
@ -100,20 +142,25 @@ class GraphIOTest(tf.test.TestCase):
with tf.Graph().as_default() as g, self.test_session(graph=g) as sess: with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
inputs = tf.contrib.learn.io.read_batch_examples( inputs = tf.contrib.learn.io.read_batch_examples(
_VALID_FILE_PATTERN, batch_size, _VALID_FILE_PATTERN, batch_size,
reader=tf.TFRecordReader, randomize_input=False, reader=tf.TFRecordReader, randomize_input=True,
num_epochs=1,
queue_capacity=queue_capacity, name=name) queue_capacity=queue_capacity, name=name)
self.assertEquals("%s:0" % name, inputs.name) self.assertEquals("%s:0" % name, inputs.name)
file_name_queue_name = "%s/file_name_queue" % name file_name_queue_name = "%s/file_name_queue" % name
file_name_queue_limit_name = (
"%s/limit_epochs/epochs" % file_name_queue_name)
file_names_name = "%s/input" % file_name_queue_name file_names_name = "%s/input" % file_name_queue_name
example_queue_name = "%s/fifo_queue" % name example_queue_name = "%s/random_shuffle_queue" % name
op_nodes = test_util.assert_ops_in_graph({ op_nodes = test_util.assert_ops_in_graph({
file_names_name: "Const", file_names_name: "Const",
file_name_queue_name: "FIFOQueue", file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader", "%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "FIFOQueue", example_queue_name: "RandomShuffleQueue",
name: "QueueDequeueMany" name: "QueueDequeueMany",
file_name_queue_limit_name: "Variable"
}, g) }, g)
self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0]) self.assertEqual(
set(_FILE_NAMES), set(sess.run(["%s:0" % file_names_name])[0]))
self.assertEqual( self.assertEqual(
queue_capacity, op_nodes[example_queue_name].attr["capacity"].i) queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)
@ -143,6 +190,34 @@ class GraphIOTest(tf.test.TestCase):
self.assertEqual( self.assertEqual(
queue_capacity, op_nodes[example_queue_name].attr["capacity"].i) queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)
def test_read_csv(self):
gfile.Glob = self._orig_glob
tempdir = tempfile.mkdtemp()
filename = os.path.join(tempdir, "file.csv")
gfile.Open(filename, "w").write("ABC\nDEF\nGHK\n")
batch_size = 1
queue_capacity = 5
name = "my_batch"
with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
inputs = tf.contrib.learn.io.read_batch_examples(
filename, batch_size,
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity, name=name)
session.run(tf.initialize_all_variables())
coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)
self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
self.assertAllEqual(session.run(inputs), [b"GHK"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)
coord.request_stop()
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()


@ -26,15 +26,17 @@ namespace tensorflow {
class SquaredLossUpdater : public DualLossUpdater { class SquaredLossUpdater : public DualLossUpdater {
public: public:
// Closed form solution that decreases the dual squared loss. // Closed form solution that decreases the dual squared loss.
// See page 23 of http://arxiv.org/pdf/1309.2375v2.pdf // See page 23 of http://arxiv.org/pdf/1309.2375v2.pdf for the derivation of
// the update rule when the example weights are equal to 1.0.
// Note: There is a typo in the formula in the paper: the denominator should
// be 1 + ||x_i||^2/(\lambda n) (without the 2 multiplier).
double ComputeUpdatedDual(const double label, const double example_weight, double ComputeUpdatedDual(const double label, const double example_weight,
const double current_dual, const double wx, const double current_dual, const double wx,
const double weighted_example_norm, const double weighted_example_norm,
const double primal_loss_unused, const double primal_loss_unused,
const double dual_loss_unused) const final { const double dual_loss_unused) const final {
const double delta_numerator = (label - current_dual - wx) * example_weight; const double delta_numerator = label - current_dual - wx;
const double delta_denominator = const double delta_denominator = 1 + weighted_example_norm * example_weight;
1 + weighted_example_norm * example_weight * example_weight * 0.5;
return current_dual + delta_numerator / delta_denominator; return current_dual + delta_numerator / delta_denominator;
} }
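
Written out, the closed-form update above (a sketch for unit example weights, reading weighted_example_norm as \|x_i\|^2 / (\lambda n), as the comment indicates) is:

    \alpha_i \leftarrow \alpha_i + \frac{y_i - \alpha_i - w^\top x_i}{1 + \|x_i\|^2 / (\lambda n)}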


@ -455,6 +455,7 @@ class SdcaWithLogisticLossTest(SdcaOptimizerTest):
# TODO(katsiaspis): add a test for the case when examples at the end of an # TODO(katsiaspis): add a test for the case when examples at the end of an
# epoch are repeated, since example id may be duplicated. # epoch are repeated, since example id may be duplicated.
class SdcaWithLinearLossTest(SdcaOptimizerTest): class SdcaWithLinearLossTest(SdcaOptimizerTest):
"""SDCA optimizer test class for linear (squared) loss.""" """SDCA optimizer test class for linear (squared) loss."""
@ -488,9 +489,11 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest):
self.assertAllClose([-20.0 / 3.0, 28.0 / 3.0], self.assertAllClose([-20.0 / 3.0, 28.0 / 3.0],
predictions.eval(), predictions.eval(),
rtol=0.005) rtol=0.005)
self.assertAllClose(0.01, # Approximate gap should be very close to 0.0. (In fact, because the gap
# is only approximate, it is likely that upon convergence the duality gap
# can have a tiny negative value).
self.assertAllClose(0.00,
lr.approximate_duality_gap().eval(), lr.approximate_duality_gap().eval(),
rtol=1e-2,
atol=1e-2) atol=1e-2)
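
For reference, the duality gap being asserted is the usual primal-minus-dual quantity; by weak duality it is non-negative for exact objectives and zero at the optimum, which is why the test now checks closeness to 0.0 with an absolute tolerance (the approximate gap reported by the optimizer can dip slightly below zero):

    \text{gap}(w, \alpha) = P(w) - D(\alpha) \ge 0, \qquad \text{gap}(w^{*}, \alpha^{*}) = 0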
def testL2Regularization(self): def testL2Regularization(self):
@ -580,7 +583,7 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest):
{'age': [1], {'age': [1],
'gender': [1]}, 14.0, 2.0), 'gender': [1]}, 14.0, 2.0),
] ]
example_weights = [1.0, 1.0] example_weights = [5.0, 3.0]
with self._single_threaded_test_session(): with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights) examples = make_example_dict(example_protos, example_weights)
@ -597,20 +600,30 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest):
for _ in xrange(_MAX_ITERATIONS): for _ in xrange(_MAX_ITERATIONS):
train_op.run() train_op.run()
# Predictions should be 8/9 of label due to minimizing regularized loss: # There are 4 (sparse) variable weights to be learned. 2 for age and 2 for
# (label - 2 * 2 * weight)^2 / 2 + L2 * 2 * weight^2 # gender. Let w_1, w_2 be age weights, w_3, w_4 be gender weights, y_1,
self.assertAllClose([-10.0 * 8 / 9, 14.0 * 8 / 9], # y_2 be the labels for examples 1 and 2 respectively and s_1, s_2 the
# corresponding *example* weights. With the given feature values, the loss
# function is given by:
# s_1/2(y_1 + 2w_1 + 2w_3)^2 + s_2/2(y_2 - 2w_2 - 2w_4)^2
# + \lambda/2 (w_1^2 + w_2^2 + w_3^2 + w_4^2). Solving for the optimal, it
# can be verified that:
# w_1* = w_3* = -2.0 s_1 y_1/(\lambda + 8 s_1) and
# w_2* = w_4* = 2 \cdot s_2 y_2/(\lambda + 8 s_2). Equivalently, due to
# regularization and example weights, the predictions are within:
# 8 \cdot s_i /(\lambda + 8 \cdot s_i) of the labels.
self.assertAllClose([-10 * 40.0 / 41.0, 14.0 * 24 / 25.0],
predictions.eval(), predictions.eval(),
rtol=0.07) atol=0.01)
def testDenseFeatures(self): def testDenseFeaturesWithDefaultWeights(self):
with self._single_threaded_test_session(): with self._single_threaded_test_session():
examples = make_dense_examples_dict( examples = make_dense_examples_dict(
dense_feature_values=[[-2.0, 0.0], [0.0, 2.0]], dense_feature_values=[[1.0, 0.0], [0.0, 1.0]],
weights=[1.0, 1.0], weights=[1.0, 1.0],
labels=[-10.0, 14.0]) labels=[10.0, -5.0])
variables = make_dense_variable_dict(2, 2) variables = make_dense_variable_dict(2, 2)
options = dict(symmetric_l2_regularization=1, options = dict(symmetric_l2_regularization=1.0,
symmetric_l1_regularization=0, symmetric_l1_regularization=0,
loss_type='squared_loss') loss_type='squared_loss')
lr = SdcaModel(CONTAINER, examples, variables, options) lr = SdcaModel(CONTAINER, examples, variables, options)
@ -621,14 +634,51 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest):
for _ in xrange(_MAX_ITERATIONS): for _ in xrange(_MAX_ITERATIONS):
train_op.run() train_op.run()
# Predictions should be 4/5 of label due to minimizing regularized loss: # The loss function for these particular features is given by:
# (label - 2 * weight)^2 / 2 + L2 * weight^2 # 1/2(label_1-w_1)^2 + 1/2(label_2-w_2)^2 + \lambda/2 (w_1^2 + w_2^2). So,
self.assertAllClose([-10.0 * 4 / 5, 14.0 * 4 / 5], # differentiating wrt to w_1, w_2 yields the following optimal values:
# w_1* = label_1/(\lambda + 1)= 10/2, w_2* =label_2/(\lambda + 1)= -5/2.
# In this case the (unnormalized regularized) loss will be:
# 1/2(10-5)^2 + 1/2(5-5/2)^2 + 1/2(5^2 + (5/2)^2) = 125.0/4. The actual
# loss should be further normalized by the sum of example weights.
self.assertAllClose([5.0, -2.5],
predictions.eval(), predictions.eval(),
rtol=0.01) rtol=0.01)
loss = lr.regularized_loss(examples) loss = lr.regularized_loss(examples)
self.assertAllClose(148.0 / 10.0, loss.eval(), atol=0.01) self.assertAllClose(125.0 / 8.0, loss.eval(), atol=0.01)
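
A quick plain-Python check of the numbers in the comment above (all values come from the test itself):

lam = 1.0                                             # symmetric_l2_regularization
w1, w2 = 10.0 / (lam + 1.0), -5.0 / (lam + 1.0)       # optimal weights: 5.0, -2.5
unnormalized = (0.5 * (10.0 - w1) ** 2 + 0.5 * (-5.0 - w2) ** 2
                + 0.5 * lam * (w1 ** 2 + w2 ** 2))    # = 125.0 / 4
normalized = unnormalized / (1.0 + 1.0)               # divide by sum of example weights
assert abs(normalized - 125.0 / 8.0) < 1e-9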
def testDenseFeaturesWithArbitraryWeights(self):
with self._single_threaded_test_session():
examples = make_dense_examples_dict(
dense_feature_values=[[1.0, 0.0], [0.0, 1.0]],
weights=[20.0, 10.0],
labels=[10.0, -5.0])
variables = make_dense_variable_dict(2, 2)
options = dict(symmetric_l2_regularization=5.0,
symmetric_l1_regularization=0,
loss_type='squared_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
predictions = lr.predictions(examples)
train_op = lr.minimize()
for _ in xrange(_MAX_ITERATIONS):
train_op.run()
# The loss function for these particular features is given by:
# 1/2 s_1 (label_1-w_1)^2 + 1/2 s_2(label_2-w_2)^2 +
# \lambda/2 (w_1^2 + w_2^2) where s_1, s_2 are the *example* weights. It
# turns out that the optimal (variable) weights are given by:
# w_1* = label_1 \cdot s_1/(\lambda + s_1)= 8.0 and
# w_2* =label_2 \cdot s_2/(\lambda + s_2)= -10/3.
# In this case the (unnormalized regularized) loss will be:
# s_1/2(8-10)^2 + s_2/2(5-10/3)^2 + 5.0/2(8^2 + (10/3)^2) = 2175.0/9. The
# actual loss should be further normalized by the sum of example weights.
self.assertAllClose([8.0, -10.0/3],
predictions.eval(),
rtol=0.01)
loss = lr.regularized_loss(examples)
self.assertAllClose(2175.0 / 270.0, loss.eval(), atol=0.01)
class SdcaWithHingeLossTest(SdcaOptimizerTest): class SdcaWithHingeLossTest(SdcaOptimizerTest):


@ -19,7 +19,10 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.losses.python.losses.loss_ops import absolute_difference from tensorflow.contrib.losses.python.losses.loss_ops import absolute_difference
from tensorflow.contrib.losses.python.losses.loss_ops import add_loss
from tensorflow.contrib.losses.python.losses.loss_ops import cosine_distance from tensorflow.contrib.losses.python.losses.loss_ops import cosine_distance
from tensorflow.contrib.losses.python.losses.loss_ops import get_losses
from tensorflow.contrib.losses.python.losses.loss_ops import get_total_loss
from tensorflow.contrib.losses.python.losses.loss_ops import log from tensorflow.contrib.losses.python.losses.loss_ops import log
from tensorflow.contrib.losses.python.losses.loss_ops import sigmoid_cross_entropy from tensorflow.contrib.losses.python.losses.loss_ops import sigmoid_cross_entropy
from tensorflow.contrib.losses.python.losses.loss_ops import softmax_cross_entropy from tensorflow.contrib.losses.python.losses.loss_ops import softmax_cross_entropy


@ -104,9 +104,11 @@ weighted average over the individual prediction errors:
weight = tf.div(weight, tf.size(weight)) weight = tf.div(weight, tf.size(weight))
loss = tf.contrib.losses.sum_of_squares(predictions, depths, weight) loss = tf.contrib.losses.sum_of_squares(predictions, depths, weight)
@@absolute_difference @@absolute_difference
@@add_loss
@@cosine_distance @@cosine_distance
@@get_losses
@@get_total_loss
@@log @@log
@@sigmoid_cross_entropy @@sigmoid_cross_entropy
@@softmax_cross_entropy @@softmax_cross_entropy
@ -252,6 +254,61 @@ def _num_present(losses, weight, per_batch=False):
return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)
def add_loss(loss):
"""Adds a externally defined loss to collection of losses.
Args:
loss: A loss `Tensor`.
"""
ops.add_to_collection(ops.GraphKeys.LOSSES, loss)
def get_losses(scope=None):
"""Gets the list of loss variables.
Args:
scope: an optional scope for filtering the losses to return.
Returns:
a list of loss variables.
"""
return ops.get_collection(ops.GraphKeys.LOSSES, scope)
def get_regularization_losses(scope=None):
"""Gets the regularization losses.
Args:
scope: an optional scope for filtering the losses to return.
Returns:
A list of loss variables.
"""
return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)
def get_total_loss(add_regularization_losses=True, name="total_loss"):
"""Returns a tensor whose value represents the total loss.
Notice that the function adds the given losses to the regularization losses.
Args:
add_regularization_losses: A boolean indicating whether or not to use the
regularization losses in the sum.
name: The name of the returned tensor.
Returns:
A `Tensor` whose value represents the total loss.
Raises:
ValueError: if `losses` is not iterable.
"""
losses = get_losses()
if add_regularization_losses:
losses += get_regularization_losses()
return math_ops.add_n(losses, name=name)
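
A short usage sketch of the new collection helpers, assuming they are exported at `tf.contrib.losses` as the import and documentation hunks above suggest; the tensors are illustrative and are registered explicitly, so the example does not rely on the built-in losses adding themselves to the collection:

import tensorflow as tf

predictions = tf.constant([0.5, 1.5])
targets = tf.constant([1.0, 1.0])

# Register externally computed losses in the LOSSES collection ...
tf.contrib.losses.add_loss(tf.reduce_mean(tf.abs(predictions - targets)))
tf.contrib.losses.add_loss(tf.reduce_mean(tf.square(predictions - targets)))

# ... inspect what has been collected, and sum everything (plus any
# REGULARIZATION_LOSSES) into a single training objective.
collected = tf.contrib.losses.get_losses()
total_loss = tf.contrib.losses.get_total_loss()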
def absolute_difference(predictions, targets, weight=1.0, scope=None): def absolute_difference(predictions, targets, weight=1.0, scope=None):
"""Adds an Absolute Difference loss to the training procedure. """Adds an Absolute Difference loss to the training procedure.


@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/simple_placer.h" #include "tensorflow/core/common_runtime/simple_placer.h"
#include <memory> #include <memory>
#include <set>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -182,6 +183,7 @@ class ColocationGraph {
Status ColocateNodes(const Node& x, const Node& y) { Status ColocateNodes(const Node& x, const Node& y) {
int x_root = FindRoot(x.id()); int x_root = FindRoot(x.id());
int y_root = FindRoot(y.id()); int y_root = FindRoot(y.id());
Status s; Status s;
if (x_root != y_root) { if (x_root != y_root) {
// Merge the sets by swinging the parent pointer of the smaller // Merge the sets by swinging the parent pointer of the smaller
@ -229,6 +231,12 @@ class ColocationGraph {
s.error_message()); s.error_message());
} }
// Transfer ids in the old group to the new one.
members_[new_root].ids_in_group.insert(
members_[old_root].ids_in_group.begin(),
members_[old_root].ids_in_group.end());
members_[old_root].ids_in_group.clear();
// Ensure that the common root has at least one supported device // Ensure that the common root has at least one supported device
// type, by computing the intersection of // type, by computing the intersection of
// members_[new_root].supported_device_types and // members_[new_root].supported_device_types and
@ -267,6 +275,9 @@ class ColocationGraph {
return Status::OK(); return Status::OK();
} }
// String containing additional debugging info on failures.
string debug_info;
// We have not yet computed the possible devices for the // We have not yet computed the possible devices for the
// colocated node set containing 'node', so we do so now using the // colocated node set containing 'node', so we do so now using the
// constraints on the root node. // constraints on the root node.
@ -310,6 +321,8 @@ class ColocationGraph {
// Return an error when a physical device that matches an explicit // Return an error when a physical device that matches an explicit
// device specification is not found. This ensures that we don't // device specification is not found. This ensures that we don't
// assign a node to GPU when the user wanted to force it on CPU. // assign a node to GPU when the user wanted to force it on CPU.
AddDebugInfo(node_root, &debug_info);
DeviceNameUtils::ParsedName specified_device_name; DeviceNameUtils::ParsedName specified_device_name;
if (DeviceNameUtils::ParseFullName(node->def().device(), if (DeviceNameUtils::ParseFullName(node->def().device(),
&specified_device_name) && &specified_device_name) &&
@ -334,16 +347,17 @@ class ColocationGraph {
node->def().device(), node->def().device(),
"' because no devices matching that specification " "' because no devices matching that specification "
"are registered in this process; available devices: ", "are registered in this process; available devices: ",
str_util::Join(device_names, ", ")); str_util::Join(device_names, ", "), debug_info);
} else if (specified_device_name.has_type) { } else if (specified_device_name.has_type) {
return errors::InvalidArgument( return errors::InvalidArgument(
"Could not satisfy explicit device specification '", "Could not satisfy explicit device specification '",
node->def().device(), "' because no supported kernel for ", node->def().device(), "' because no supported kernel for ",
specified_device_name.type, " devices is available"); specified_device_name.type, " devices is available.",
debug_info);
} else { } else {
return errors::InvalidArgument( return errors::InvalidArgument(
"Could not satisfy explicit device specification '", "Could not satisfy explicit device specification '",
node->def().device()); node->def().device(), debug_info);
} }
} else { } else {
// The specified device may be a valid device but the // The specified device may be a valid device but the
@ -355,7 +369,7 @@ class ColocationGraph {
"required incompatible device '", "required incompatible device '",
DeviceNameUtils::ParsedNameToString( DeviceNameUtils::ParsedNameToString(
members_[node_root].device_name), members_[node_root].device_name),
"'"); "'", debug_info);
} }
} }
} else { } else {
@ -368,10 +382,11 @@ class ColocationGraph {
device_set_->devices(), members_[node_root].supported_device_types); device_set_->devices(), members_[node_root].supported_device_types);
if (devices.empty()) { if (devices.empty()) {
AddDebugInfo(node_root, &debug_info);
return errors::InvalidArgument( return errors::InvalidArgument(
"Node had no OpKernel registered to support this operation: ", "Node had no OpKernel registered to support this operation: ",
"Operation was ", node->type_string(), " and inputs were ", "Operation was ", node->type_string(), " and inputs were ",
DataTypeVectorString(node->input_types())); DataTypeVectorString(node->input_types()), debug_info);
} }
} }
@ -390,6 +405,15 @@ class ColocationGraph {
// id if it is a root. parent <= 0 indicates that this member is invalid. // id if it is a root. parent <= 0 indicates that this member is invalid.
int parent = -1; int parent = -1;
// The set of ids that are part of the disjoint node set forest.
//
// This is only fully specified in the root of a disjoint
// node set forest.
std::set<int> ids_in_group;
// The type of the op for this node.
string op_type;
// A proxy for the depth of the tree that is used to prefer // A proxy for the depth of the tree that is used to prefer
// connecting smaller trees to larger trees when merging disjoint // connecting smaller trees to larger trees when merging disjoint
// sets. // sets.
@ -410,8 +434,41 @@ class ColocationGraph {
std::vector<Device*> possible_devices; std::vector<Device*> possible_devices;
}; };
// Adds debugging info to 'output' for the node referred to by
// 'node_root'.
void AddDebugInfo(const int node_root, string* output) {
if (members_[node_root].ids_in_group.size() > 1) {
strings::StrAppend(output, "\nColocation Debug Info:\n");
// If this node is part of a colocation group, then we want to
// collect the mapping of ops to supported devices, so that
// the user can see why an unsatisfiable placement occurred.
strings::StrAppend(
output, "Colocation group had the following types and devices: ");
std::unordered_map<string, string> type_to_devices;
for (const int id : members_[node_root].ids_in_group) {
const string& op_type = members_[id].op_type;
string devices_registered;
for (const auto& device_type : members_[id].supported_device_types) {
strings::StrAppend(&devices_registered, DeviceTypeString(device_type),
" ");
}
type_to_devices[op_type] = devices_registered;
}
for (const auto& td : type_to_devices) {
strings::StrAppend(output, "\n", td.first, ": ", td.second);
}
}
}
Status InitializeMember(const Node& node, Member* member) { Status InitializeMember(const Node& node, Member* member) {
const int id = node.id(); const int id = node.id();
member->ids_in_group.insert(id);
member->op_type = node.type_string();
if (id < 0) { if (id < 0) {
return errors::InvalidArgument("Node id was not positive: ", id); return errors::InvalidArgument("Node id was not positive: ", id);
} }


@ -729,6 +729,12 @@ TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) {
EXPECT_TRUE(StringPiece(s.error_message()) EXPECT_TRUE(StringPiece(s.error_message())
.contains("colocated with a group of nodes that required " .contains("colocated with a group of nodes that required "
"incompatible device")); "incompatible device"));
// The error message should contain information that indicates which
// op types have which registered device types.
EXPECT_TRUE(StringPiece(s.error_message()).contains("VariableGPU: GPU")) << s;
EXPECT_TRUE(StringPiece(s.error_message()).contains("TestAssign: GPU CPU"))
<< s;
} }
// Test that placement fails when an unknown device is requested. // Test that placement fails when an unknown device is requested.


@ -13,75 +13,68 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/framework/op.h"
#include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/binary_linalg_ops_common.h"
namespace tensorflow { namespace tensorflow {
template <typename T> template <typename Scalar, bool SupportsBatchOperationT>
class CholeskyGrad : public OpKernel { class CholeskyGrad
: public BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
public: public:
explicit CholeskyGrad(OpKernelConstruction* context) : OpKernel(context) {} explicit CholeskyGrad(OpKernelConstruction* context)
: BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
~CholeskyGrad() override {}
using Matrix = using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>; using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>; using MatrixMap = Eigen::Map<Matrix>;
using ConstRef = Eigen::Ref<const Matrix>; using ConstRef = Eigen::Ref<const Matrix>;
using Ref = Eigen::Ref<Matrix>; using Ref = Eigen::Ref<Matrix>;
void Compute(OpKernelContext* context) override { TensorShape GetOutputMatrixShape(
const Tensor& input_tensor_l = context->input(0); const TensorShape& input_matrix_l_full_shape,
const Tensor& input_tensor_grad = context->input(1); const TensorShape& input_matrix_grad_shape) override {
// Check that input tensors represent a matrix. return input_matrix_l_full_shape;
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_l.shape()), }
errors::InvalidArgument("In[0] is not a matrix"));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_grad.shape()),
errors::InvalidArgument("In[1] is not a matrix"));
// Check that input tensors are square.
OP_REQUIRES(context,
input_tensor_l.dim_size(0) == input_tensor_l.dim_size(1),
errors::InvalidArgument("Input matrix must be square."));
OP_REQUIRES(context,
input_tensor_grad.dim_size(0) == input_tensor_grad.dim_size(1),
errors::InvalidArgument("Input matrix must be square."));
// Check that input tensors are of same size. int64 GetCostPerUnit(const TensorShape& input_matrix_shape,
OP_REQUIRES(context, const TensorShape& rhs_matrix_shape) override {
input_tensor_l.dim_size(0) == input_tensor_grad.dim_size(0), const int64 rows = input_matrix_shape.dim_size(0);
errors::InvalidArgument("Input matrices must be same size.")); if (rows > (1LL << 20)) {
// A big number to cap the cost in case overflow.
// Create an output tensor return kint64max;
Tensor* output_tensor = NULL; } else {
OP_REQUIRES_OK(context, context->allocate_output( return rows * rows * rows;
0, input_tensor_grad.shape(), &output_tensor));
if (output_tensor->NumElements() == 0) {
// the output shape is a 0-element matrix, so there is nothing to do.
return;
} }
// The next lines are necessary to get Eigen matrix behaviour. }
const ConstMatrixMap input_matrix_l_full(input_tensor_l.flat<T>().data(),
input_tensor_l.dim_size(0),
input_tensor_l.dim_size(1));
const ConstMatrixMap input_matrix_grad(input_tensor_grad.flat<T>().data(),
input_tensor_grad.dim_size(0),
input_tensor_grad.dim_size(1));
MatrixMap output_matrix(output_tensor->template flat<T>().data(),
input_tensor_l.dim_size(0),
input_tensor_l.dim_size(1));
// Algorithm only depends on lower triangular half on input_tensor_l. void ComputeMatrix(OpKernelContext* context,
const ConstMatrixMap& input_matrix_l_full,
const ConstMatrixMap& input_matrix_grad,
MatrixMap* output_matrix) override {
OP_REQUIRES(context,
input_matrix_l_full.rows() == input_matrix_l_full.cols(),
errors::InvalidArgument("Input matrix must be square."));
OP_REQUIRES(
context, input_matrix_l_full.cols() == input_matrix_grad.cols(),
errors::InvalidArgument(
"Input matrix and gradient must have same number of cols."));
OP_REQUIRES(
context, input_matrix_l_full.rows() == input_matrix_grad.rows(),
errors::InvalidArgument(
"Input matrix and gradient must have same number of rows."));
// Algorithm only depends on lower triangular half on input_matrix_l.
const Matrix input_matrix_l = const Matrix input_matrix_l =
input_matrix_l_full.template triangularView<Eigen::Lower>(); input_matrix_l_full.template triangularView<Eigen::Lower>();
// Algorithm only depends on lower triangular half on input_matrix_grad. // Algorithm only depends on lower triangular half on input_matrix_grad.
output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>(); *output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();
const int64 kMatrixSize = input_matrix_l.rows(); const int64 kMatrixSize = input_matrix_l.rows();
const int64 kMaxBlockSize = 32; const int64 kMaxBlockSize = 32;
@ -104,20 +97,21 @@ class CholeskyGrad : public OpKernel {
auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin); auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin);
auto B_bar = auto B_bar =
output_matrix.block(block_end, 0, trailing_size, block_begin); output_matrix->block(block_end, 0, trailing_size, block_begin);
auto C = input_matrix_l.block(block_end, block_begin, trailing_size, auto C = input_matrix_l.block(block_end, block_begin, trailing_size,
block_size); block_size);
auto C_bar = output_matrix.block(block_end, block_begin, trailing_size, auto C_bar = output_matrix->block(block_end, block_begin, trailing_size,
block_size); block_size);
auto D = input_matrix_l.block(block_begin, block_begin, block_size, auto D = input_matrix_l.block(block_begin, block_begin, block_size,
block_size); block_size);
auto D_bar = auto D_bar = output_matrix->block(block_begin, block_begin, block_size,
output_matrix.block(block_begin, block_begin, block_size, block_size); block_size);
auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin); auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin);
auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin); auto R_bar =
output_matrix->block(block_begin, 0, block_size, block_begin);
C_bar = D.adjoint().template triangularView<Eigen::Upper>() C_bar = D.adjoint().template triangularView<Eigen::Upper>()
.solve(C_bar.adjoint()).adjoint(); .solve(C_bar.adjoint()).adjoint();
@ -127,9 +121,11 @@ class CholeskyGrad : public OpKernel {
CholeskyGradUnblocked(D, D_bar); CholeskyGradUnblocked(D, D_bar);
R_bar -= (D_bar + D_bar.adjoint()) * R; R_bar -= (D_bar + D_bar.adjoint()) * R;
} }
output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval(); *output_matrix =
(0.5 * (*output_matrix + output_matrix->transpose())).eval();
} }
void CholeskyGradUnblocked(const ConstRef l_block, Ref grad_block) {
void CholeskyGradUnblocked(const ConstRef& l_block, Ref grad_block) {
const int64 kMatrixSize = l_block.rows(); const int64 kMatrixSize = l_block.rows();
for (int64 k = kMatrixSize - 1; k >= 0; k--) { for (int64 k = kMatrixSize - 1; k >= 0; k--) {
/* This shows the block structure. /* This shows the block structure.
@ -166,6 +162,11 @@ class CholeskyGrad : public OpKernel {
} }
}; };
REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<float>), float); REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<float, false>), float);
REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<double>), double); REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<double, false>),
double);
REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<float, true>),
float);
REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<double, true>),
double);
} // namespace tensorflow } // namespace tensorflow


@ -64,8 +64,7 @@ class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
AllocatorAttributes(), allocation_attr)); AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) { if (!allocation_status.ok()) {
return perftools::gputools::port::StatusOr< return perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>( perftools::gputools::DeviceMemory<uint8>>();
AsDeviceMemory<uint8>(nullptr, 0));
} }
// Hold the reference of the allocated tensors until the end of the // Hold the reference of the allocated tensors until the end of the
// allocator. // allocator.


@ -305,7 +305,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
const int out_offset = const int out_offset =
(b * params.out_height + ph) * params.out_width + pw; (b * params.out_height + ph) * params.out_width + pw;
out_mat.col(out_offset) += in_mat.col(in_offset); out_mat.col(out_offset) += in_mat.col(in_offset);
out_count(out_offset)++; out_count(out_offset) += T(1);
} }
} }
} }


@ -3175,6 +3175,31 @@ op {
} }
} }
} }
op {
name: "BatchCholeskyGrad"
input_arg {
name: "l"
type_attr: "T"
}
input_arg {
name: "grad"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
}
op { op {
name: "BatchFFT" name: "BatchFFT"
input_arg { input_arg {


@ -129,11 +129,34 @@ REGISTER_OP("CholeskyGrad")
.Doc(R"doc( .Doc(R"doc(
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm. Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
For an explanation see "Differentiation of the Cholesky algorithm" by Iain Murray http://arxiv.org/abs/1602.07527. For an explanation see "Differentiation of the Cholesky algorithm" by
Iain Murray http://arxiv.org/abs/1602.07527.
l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix. l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.
grad: df/dl where f is some scalar function. Shape is `[M, M]'. Algorithm depends only on lower triangular part of this matrix. Algorithm depends only on lower triangular part of this matrix.
output: Symmetrized version of df/dA . Shape is `[M, M]' grad: df/dl where f is some scalar function. Shape is `[M, M]'.
Algorithm depends only on lower triangular part of this matrix.
output: Symmetrized version of df/dA . Shape is `[M, M]'.
)doc");
REGISTER_OP("BatchCholeskyGrad")
.Input("l: T")
.Input("grad: T")
.Output("output: T")
.Attr("T: {float, double}")
.Doc(R"doc(
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.
For an explanation see "Differentiation of the Cholesky algorithm" by
Iain Murray http://arxiv.org/abs/1602.07527.
l: Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.
Algorithm depends only on lower triangular part of the innermost matrices of
this tensor.
grad: df/dl where f is some scalar function. Shape is `[..., M, M]'.
Algorithm depends only on lower triangular part of the innermost matrices of
this tensor.
output: Symmetrized version of df/dA . Shape is `[..., M, M]'
)doc"); )doc");
REGISTER_OP("SelfAdjointEig") REGISTER_OP("SelfAdjointEig")


@ -1397,6 +1397,36 @@ op {
summary: "Calculates the Cholesky decomposition of a batch of square matrices." summary: "Calculates the Cholesky decomposition of a batch of square matrices."
description: "The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions\nform square matrices, with the same constraints as the single matrix Cholesky\ndecomposition above. The output is a tensor of the same shape as the input\ncontaining the Cholesky decompositions for all input submatrices `[..., :, :]`." description: "The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions\nform square matrices, with the same constraints as the single matrix Cholesky\ndecomposition above. The output is a tensor of the same shape as the input\ncontaining the Cholesky decompositions for all input submatrices `[..., :, :]`."
} }
op {
name: "BatchCholeskyGrad"
input_arg {
name: "l"
description: "Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.\nAlgorithm depends only on lower triangular part of the innermost matrices of\nthis tensor."
type_attr: "T"
}
input_arg {
name: "grad"
description: "df/dl where f is some scalar function. Shape is `[..., M, M]\'.\nAlgorithm depends only on lower triangular part of the innermost matrices of\nthis tensor."
type_attr: "T"
}
output_arg {
name: "output"
description: "Symmetrized version of df/dA . Shape is `[..., M, M]\'"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Calculates the reverse mode backpropagated gradient of the Cholesky algorithm."
description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by\nIain Murray http://arxiv.org/abs/1602.07527."
}
op { op {
name: "BatchFFT" name: "BatchFFT"
input_arg { input_arg {
@ -2482,17 +2512,17 @@ op {
name: "CholeskyGrad" name: "CholeskyGrad"
input_arg { input_arg {
name: "l" name: "l"
description: "Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix." description: "Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.\nAlgorithm depends only on lower triangular part of this matrix."
type_attr: "T" type_attr: "T"
} }
input_arg { input_arg {
name: "grad" name: "grad"
description: "df/dl where f is some scalar function. Shape is `[M, M]\'. Algorithm depends only on lower triangular part of this matrix." description: "df/dl where f is some scalar function. Shape is `[M, M]\'.\nAlgorithm depends only on lower triangular part of this matrix."
type_attr: "T" type_attr: "T"
} }
output_arg { output_arg {
name: "output" name: "output"
description: "Symmetrized version of df/dA . Shape is `[M, M]\'" description: "Symmetrized version of df/dA . Shape is `[M, M]\'."
type_attr: "T" type_attr: "T"
} }
attr { attr {
@ -2506,7 +2536,7 @@ op {
} }
} }
summary: "Calculates the reverse mode backpropagated gradient of the Cholesky algorithm." summary: "Calculates the reverse mode backpropagated gradient of the Cholesky algorithm."
description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by Iain Murray http://arxiv.org/abs/1602.07527." description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by\nIain Murray http://arxiv.org/abs/1602.07527."
} }
op { op {
name: "Complex" name: "Complex"
@ -11482,7 +11512,7 @@ op {
} }
} }
summary: "Computes the sum of elements across dimensions of a SparseTensor." summary: "Computes the sum of elements across dimensions of a SparseTensor."
description: "This Op takes a SparseTensor and is the sparse counterpart to\n`tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor`\ninstead of a sparse one.\n\nReduces `sp_input` along the dimensions given in `reduction_axes`. Unless\n`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in\n`reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained\nwith length 1.\n\nIf `reduction_axes` has no entries, all dimensions are reduced, and a tensor\nwith a single element is returned." description: "This Op takes a SparseTensor and is the sparse counterpart to\n`tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor`\ninstead of a sparse one.\n\nReduces `sp_input` along the dimensions given in `reduction_axes`. Unless\n`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in\n`reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained\nwith length 1.\n\nIf `reduction_axes` has no entries, all dimensions are reduced, and a tensor\nwith a single element is returned. Additionally, the axes can be negative,\nwhich are interpreted according to the indexing rules in Python."
} }
op { op {
name: "SparseReorder" name: "SparseReorder"


@ -52,11 +52,11 @@ def train():
# Input placehoolders # Input placehoolders
with tf.name_scope('input'): with tf.name_scope('input'):
x = tf.placeholder(tf.float32, [None, 784], name='x-input') x = tf.placeholder(tf.float32, [None, 784], name='x-input')
y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
with tf.name_scope('input_reshape'):
image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
tf.image_summary('input', image_shaped_input, 10) tf.image_summary('input', image_shaped_input, 10)
y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
keep_prob = tf.placeholder(tf.float32)
tf.scalar_summary('dropout_keep_probability', keep_prob)
# We can't initialize these variables to 0 - the network will get stuck. # We can't initialize these variables to 0 - the network will get stuck.
def weight_variable(shape): def weight_variable(shape):
@ -105,7 +105,12 @@ def train():
return activations return activations
hidden1 = nn_layer(x, 784, 500, 'layer1') hidden1 = nn_layer(x, 784, 500, 'layer1')
dropped = tf.nn.dropout(hidden1, keep_prob)
with tf.name_scope('dropout'):
keep_prob = tf.placeholder(tf.float32)
tf.scalar_summary('dropout_keep_probability', keep_prob)
dropped = tf.nn.dropout(hidden1, keep_prob)
y = nn_layer(dropped, 500, 10, 'layer2', act=tf.nn.softmax) y = nn_layer(dropped, 500, 10, 'layer2', act=tf.nn.softmax)
with tf.name_scope('cross_entropy'): with tf.name_scope('cross_entropy'):
@ -151,9 +156,20 @@ def train():
summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False)) summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
test_writer.add_summary(summary, i) test_writer.add_summary(summary, i)
print('Accuracy at step %s: %s' % (i, acc)) print('Accuracy at step %s: %s' % (i, acc))
else: # Record train set summarieis, and train else: # Record train set summaries, and train
summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True)) if i % 100 == 99: # Record execution stats
train_writer.add_summary(summary, i) run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, _ = sess.run([merged, train_step],
feed_dict=feed_dict(True),
options=run_options,
run_metadata=run_metadata)
train_writer.add_run_metadata(run_metadata, 'step%d' % i)
train_writer.add_summary(summary, i)
print('Adding run metadata for', i)
else: # Record a summary
summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
train_writer.add_summary(summary, i)
def main(_): def main(_):
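The run metadata collected with `FULL_TRACE` above can also be exported as a Chrome trace. A standalone, hedged sketch, assuming the `timeline` module is available in this build:

```python
import tensorflow as tf
from tensorflow.python.client import timeline

# Trace a trivial op and dump a chrome://tracing file.
a = tf.matmul(tf.ones([4, 4]), tf.ones([4, 4]))
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
with tf.Session() as sess:
    sess.run(a, options=run_options, run_metadata=run_metadata)

tl = timeline.Timeline(run_metadata.step_stats)
with open('/tmp/timeline.json', 'w') as f:
    f.write(tl.generate_chrome_trace_format())
```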


@ -1338,9 +1338,9 @@ Variance of each batch member.
- - - - - -
### `class tf.contrib.distributions.Gaussian` {#Gaussian} ### `class tf.contrib.distributions.Normal` {#Normal}
The scalar Gaussian distribution with mean and stddev parameters mu, sigma. The scalar Normal distribution with mean and stddev parameters mu, sigma.
#### Mathematical details #### Mathematical details
@ -1353,15 +1353,15 @@ The PDF of this distribution is:
Examples of initialization of one or a batch of distributions. Examples of initialization of one or a batch of distributions.
```python ```python
# Define a single scalar Gaussian distribution. # Define a single scalar Normal distribution.
dist = tf.contrib.distributions.Gaussian(mu=0, sigma=3) dist = tf.contrib.distributions.Normal(mu=0, sigma=3)
# Evaluate the cdf at 1, returning a scalar. # Evaluate the cdf at 1, returning a scalar.
dist.cdf(1) dist.cdf(1)
# Define a batch of two scalar valued Gaussians. # Define a batch of two scalar valued Normals.
# The first has mean 1 and standard deviation 11, the second 2 and 22. # The first has mean 1 and standard deviation 11, the second 2 and 22.
dist = tf.contrib.distributions.Gaussian(mu=[1, 2.], sigma=[11, 22.]) dist = tf.contrib.distributions.Normal(mu=[1, 2.], sigma=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5, # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor. # returning a length two tensor.
@ -1374,9 +1374,9 @@ dist.sample(3)
Arguments are broadcast when possible. Arguments are broadcast when possible.
```python ```python
# Define a batch of two scalar valued Gaussians. # Define a batch of two scalar valued Normals.
# Both have mean 1, but different standard deviations. # Both have mean 1, but different standard deviations.
dist = tf.contrib.distributions.Gaussian(mu=1, sigma=[11, 22.]) dist = tf.contrib.distributions.Normal(mu=1, sigma=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0, # Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor. # returning a length 2 tensor.
@ -1384,9 +1384,9 @@ dist.pdf(3.0)
``` ```
- - - - - -
#### `tf.contrib.distributions.Gaussian.__init__(mu, sigma, name=None)` {#Gaussian.__init__} #### `tf.contrib.distributions.Normal.__init__(mu, sigma, name=None)` {#Normal.__init__}
Construct Gaussian distributions with mean and stddev `mu` and `sigma`. Construct Normal distributions with mean and stddev `mu` and `sigma`.
The parameters `mu` and `sigma` must be shaped in a way that supports The parameters `mu` and `sigma` must be shaped in a way that supports
broadcasting (e.g. `mu + sigma` is a valid operation). broadcasting (e.g. `mu + sigma` is a valid operation).
@ -1407,9 +1407,9 @@ broadcasting (e.g. `mu + sigma` is a valid operation).
- - - - - -
#### `tf.contrib.distributions.Gaussian.cdf(x, name=None)` {#Gaussian.cdf} #### `tf.contrib.distributions.Normal.cdf(x, name=None)` {#Normal.cdf}
CDF of observations in `x` under these Gaussian distribution(s). CDF of observations in `x` under these Normal distribution(s).
##### Args: ##### Args:
@ -1425,16 +1425,16 @@ CDF of observations in `x` under these Gaussian distribution(s).
- - - - - -
#### `tf.contrib.distributions.Gaussian.dtype` {#Gaussian.dtype} #### `tf.contrib.distributions.Normal.dtype` {#Normal.dtype}
- - - - - -
#### `tf.contrib.distributions.Gaussian.entropy(name=None)` {#Gaussian.entropy} #### `tf.contrib.distributions.Normal.entropy(name=None)` {#Normal.entropy}
The entropy of Gaussian distribution(s). The entropy of Normal distribution(s).
##### Args: ##### Args:
@ -1449,16 +1449,16 @@ The entropy of Gaussian distribution(s).
- - - - - -
#### `tf.contrib.distributions.Gaussian.is_reparameterized` {#Gaussian.is_reparameterized} #### `tf.contrib.distributions.Normal.is_reparameterized` {#Normal.is_reparameterized}
- - - - - -
#### `tf.contrib.distributions.Gaussian.log_cdf(x, name=None)` {#Gaussian.log_cdf} #### `tf.contrib.distributions.Normal.log_cdf(x, name=None)` {#Normal.log_cdf}
Log CDF of observations `x` under these Gaussian distribution(s). Log CDF of observations `x` under these Normal distribution(s).
##### Args: ##### Args:
@ -1474,9 +1474,9 @@ Log CDF of observations `x` under these Gaussian distribution(s).
- - - - - -
#### `tf.contrib.distributions.Gaussian.log_pdf(x, name=None)` {#Gaussian.log_pdf} #### `tf.contrib.distributions.Normal.log_pdf(x, name=None)` {#Normal.log_pdf}
Log pdf of observations in `x` under these Gaussian distribution(s). Log pdf of observations in `x` under these Normal distribution(s).
##### Args: ##### Args:
@ -1492,23 +1492,23 @@ Log pdf of observations in `x` under these Gaussian distribution(s).
- - - - - -
#### `tf.contrib.distributions.Gaussian.mean` {#Gaussian.mean} #### `tf.contrib.distributions.Normal.mean` {#Normal.mean}
- - - - - -
#### `tf.contrib.distributions.Gaussian.mu` {#Gaussian.mu} #### `tf.contrib.distributions.Normal.mu` {#Normal.mu}
- - - - - -
#### `tf.contrib.distributions.Gaussian.pdf(x, name=None)` {#Gaussian.pdf} #### `tf.contrib.distributions.Normal.pdf(x, name=None)` {#Normal.pdf}
The PDF of observations in `x` under these Gaussian distribution(s). The PDF of observations in `x` under these Normal distribution(s).
##### Args: ##### Args:
@ -1524,9 +1524,9 @@ The PDF of observations in `x` under these Gaussian distribution(s).
- - - - - -
#### `tf.contrib.distributions.Gaussian.sample(n, seed=None, name=None)` {#Gaussian.sample} #### `tf.contrib.distributions.Normal.sample(n, seed=None, name=None)` {#Normal.sample}
Sample `n` observations from the Gaussian Distributions. Sample `n` observations from the Normal Distributions.
##### Args: ##### Args:
@ -1544,7 +1544,7 @@ Sample `n` observations from the Gaussian Distributions.
- - - - - -
#### `tf.contrib.distributions.Gaussian.sigma` {#Gaussian.sigma} #### `tf.contrib.distributions.Normal.sigma` {#Normal.sigma}
@ -2443,26 +2443,26 @@ probability includes a combinatorial coefficient.
Functions that transform conjugate prior/likelihood pairs to distributions Functions that transform conjugate prior/likelihood pairs to distributions
representing the posterior or posterior predictive. representing the posterior or posterior predictive.
### Gaussian likelihood with conjugate prior. ### Normal likelihood with conjugate prior.
- - - - - -
### `tf.contrib.distributions.gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#gaussian_conjugates_known_sigma_posterior} ### `tf.contrib.distributions.normal_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#normal_conjugates_known_sigma_posterior}
Posterior Gaussian distribution with conjugate prior on the mean. Posterior Normal distribution with conjugate prior on the mean.
This model assumes that `n` observations (with sum `s`) come from a This model assumes that `n` observations (with sum `s`) come from a
Gaussian with unknown mean `mu` (described by the Gaussian `prior`) Normal with unknown mean `mu` (described by the Normal `prior`)
and known variance `sigma^2`. The "known sigma posterior" is and known variance `sigma^2`. The "known sigma posterior" is
the distribution of the unknown `mu`. the distribution of the unknown `mu`.
Accepts a prior Gaussian distribution object, having parameters Accepts a prior Normal distribution object, having parameters
`mu0` and `sigma0`, as well as known `sigma` values of the predictive `mu0` and `sigma0`, as well as known `sigma` values of the predictive
distribution(s) (also assumed Gaussian), distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations). `n` (the number(s) of observations).
Returns a posterior (also Gaussian) distribution object, with parameters Returns a posterior (also Normal) distribution object, with parameters
`(mu', sigma'^2)`, where: `(mu', sigma'^2)`, where:
``` ```
@ -2477,7 +2477,7 @@ will broadcast in the case of multidimensional sets of parameters.
##### Args: ##### Args:
* <b>`prior`</b>: `Gaussian` object of type `dtype`: * <b>`prior`</b>: `Normal` object of type `dtype`:
the prior distribution having parameters `(mu0, sigma0)`. the prior distribution having parameters `(mu0, sigma0)`.
* <b>`sigma`</b>: tensor of type `dtype`, taking values `sigma > 0`. * <b>`sigma`</b>: tensor of type `dtype`, taking values `sigma > 0`.
The known stddev parameter(s). The known stddev parameter(s).
@ -2486,35 +2486,35 @@ will broadcast in the case of multidimensional sets of parameters.
##### Returns: ##### Returns:
A new Gaussian posterior distribution object for the unknown observation A new Normal posterior distribution object for the unknown observation
mean `mu`. mean `mu`.
##### Raises: ##### Raises:
* <b>`TypeError`</b>: if dtype of `s` does not match `dtype`, or `prior` is not a * <b>`TypeError`</b>: if dtype of `s` does not match `dtype`, or `prior` is not a
Gaussian object. Normal object.
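A hedged usage sketch for the renamed helper, with illustrative numbers (prior N(0, 1), known sigma 2, ten observations summing to 5):

```python
import tensorflow as tf

prior = tf.contrib.distributions.Normal(mu=0., sigma=1.)
posterior = tf.contrib.distributions.normal_conjugates_known_sigma_posterior(
    prior, sigma=2., s=5., n=10.)

with tf.Session() as sess:
    # Posterior mean and stddev of the unknown observation mean `mu`.
    print(sess.run([posterior.mu, posterior.sigma]))
```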
- - - - - -
### `tf.contrib.distributions.gaussian_congugates_known_sigma_predictive(prior, sigma, s, n)` {#gaussian_congugates_known_sigma_predictive} ### `tf.contrib.distributions.normal_congugates_known_sigma_predictive(prior, sigma, s, n)` {#normal_congugates_known_sigma_predictive}
Posterior predictive Gaussian distribution w. conjugate prior on the mean. Posterior predictive Normal distribution w. conjugate prior on the mean.
This model assumes that `n` observations (with sum `s`) come from a This model assumes that `n` observations (with sum `s`) come from a
Gaussian with unknown mean `mu` (described by the Gaussian `prior`) Normal with unknown mean `mu` (described by the Normal `prior`)
and known variance `sigma^2`. The "known sigma predictive" and known variance `sigma^2`. The "known sigma predictive"
is the distribution of new observations, conditioned on the existing is the distribution of new observations, conditioned on the existing
observations and our prior. observations and our prior.
Accepts a prior Gaussian distribution object, having parameters Accepts a prior Normal distribution object, having parameters
`mu0` and `sigma0`, as well as known `sigma` values of the predictive `mu0` and `sigma0`, as well as known `sigma` values of the predictive
distribution(s) (also assumed Gaussian), distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations). `n` (the number(s) of observations).
Calculates the Gaussian distribution(s) `p(x | sigma^2)`: Calculates the Normal distribution(s) `p(x | sigma^2)`:
``` ```
p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu
@ -2536,7 +2536,7 @@ will broadcast in the case of multidimensional sets of parameters.
##### Args: ##### Args:
* <b>`prior`</b>: `Gaussian` object of type `dtype`: * <b>`prior`</b>: `Normal` object of type `dtype`:
the prior distribution having parameters `(mu0, sigma0)`. the prior distribution having parameters `(mu0, sigma0)`.
* <b>`sigma`</b>: tensor of type `dtype`, taking values `sigma > 0`. * <b>`sigma`</b>: tensor of type `dtype`, taking values `sigma > 0`.
The known stddev parameter(s). The known stddev parameter(s).
@ -2545,12 +2545,12 @@ will broadcast in the case of multidimensional sets of parameters.
##### Returns: ##### Returns:
A new Gaussian predictive distribution object. A new Normal predictive distribution object.
##### Raises: ##### Raises:
* <b>`TypeError`</b>: if dtype of `s` does not match `dtype`, or `prior` is not a * <b>`TypeError`</b>: if dtype of `s` does not match `dtype`, or `prior` is not a
Gaussian object. Normal object.


@ -339,7 +339,7 @@ Optimize weights given a loss.
- - - - - -
### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, gradient_noise_scale=None, gradient_multipliers=None, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, variables=None, name=None)` {#optimize_loss} ### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, gradient_noise_scale=None, gradient_multipliers=None, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, update_ops=None, variables=None, name=None)` {#optimize_loss}
Given loss and parameters for optimizer, returns a training op. Given loss and parameters for optimizer, returns a training op.
@ -369,6 +369,8 @@ Given loss and parameters for optimizer, returns a training op.
Can be used to implement any learning rate decay Can be used to implement any learning rate decay
functions. functions.
For example: tf.train.exponential_decay. For example: tf.train.exponential_decay.
* <b>`update_ops`</b>: list of update `Operation`s to execute at each step. If `None`,
uses elements of UPDATE_OPS collection.
* <b>`variables`</b>: list of variables to optimize or * <b>`variables`</b>: list of variables to optimize or
`None` to use all trainable variables. `None` to use all trainable variables.
* <b>`name`</b>: The name for this operation is used to scope operations and summaries. * <b>`name`</b>: The name for this operation is used to scope operations and summaries.
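A hedged sketch of the new `update_ops` argument; the running-mean statistic below is purely illustrative. Ops added to `GraphKeys.UPDATE_OPS` are run before the loss whenever `update_ops` is left as `None`:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
w = tf.Variable(tf.zeros([4, 1]))
running_mean = tf.Variable(tf.zeros([4]), trainable=False)

# A statistic refreshed on every step but not trained directly.
update_mean = tf.assign(running_mean,
                        0.9 * running_mean + 0.1 * tf.reduce_mean(x, 0))
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)

loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = tf.contrib.layers.optimize_loss(
    loss, global_step, learning_rate=0.1, optimizer='SGD')  # update_ops=None

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    sess.run(train_op, feed_dict={x: [[1., 2., 3., 4.]]})
```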


@ -3396,7 +3396,7 @@ Extracts numpy matrix from pandas DataFrame.
- - - - - -
### `tf.contrib.learn.read_batch_examples(file_pattern, batch_size, reader, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_examples')` {#read_batch_examples} ### `tf.contrib.learn.read_batch_examples(file_pattern, batch_size, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, num_threads=1, name=None)` {#read_batch_examples}
Adds operations to read, queue, batch `Example` protos. Adds operations to read, queue, batch `Example` protos.
@ -3418,6 +3418,10 @@ All ops are added to the default graph.
* <b>`reader`</b>: A function or class that returns an object with * <b>`reader`</b>: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor). `read` method, (filename tensor) -> (example tensor).
* <b>`randomize_input`</b>: Whether the input should be randomized. * <b>`randomize_input`</b>: Whether the input should be randomized.
* <b>`num_epochs`</b>: Integer specifying the number of times to read through the
dataset. If `None`, cycles through the dataset forever.
NOTE - If specified, creates a variable that must be initialized, so call
`tf.initialize_all_variables()` as shown in the tests.
* <b>`queue_capacity`</b>: Capacity for input queue. * <b>`queue_capacity`</b>: Capacity for input queue.
* <b>`num_threads`</b>: The number of threads enqueuing examples. * <b>`num_threads`</b>: The number of threads enqueuing examples.
* <b>`name`</b>: Name of resulting op. * <b>`name`</b>: Name of resulting op.
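A hedged end-to-end sketch of `num_epochs` (the file pattern is illustrative); the epoch counter is the variable that must be initialized, and `OutOfRangeError` signals exhaustion:

```python
import tensorflow as tf

examples = tf.contrib.learn.read_batch_examples(
    file_pattern='/tmp/data/*.tfrecord',   # illustrative path
    batch_size=32, reader=tf.TFRecordReader,
    randomize_input=True, num_epochs=2)

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while not coord.should_stop():
            batch = sess.run(examples)     # serialized Example protos
    except tf.errors.OutOfRangeError:
        pass                               # raised once num_epochs passes are done
    finally:
        coord.request_stop()
        coord.join(threads)
```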
@ -3434,7 +3438,7 @@ All ops are added to the default graph.
- - - - - -
### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_examples')` {#read_batch_features} ### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name=None)` {#read_batch_features}
Adds operations to read, queue, batch and parse `Example` protos. Adds operations to read, queue, batch and parse `Example` protos.
@ -3459,8 +3463,13 @@ All ops are added to the default graph.
* <b>`reader`</b>: A function or class that returns an object with * <b>`reader`</b>: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor). `read` method, (filename tensor) -> (example tensor).
* <b>`randomize_input`</b>: Whether the input should be randomized. * <b>`randomize_input`</b>: Whether the input should be randomized.
* <b>`num_epochs`</b>: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.initialize_all_variables() as shown in the tests.
* <b>`queue_capacity`</b>: Capacity for input queue. * <b>`queue_capacity`</b>: Capacity for input queue.
* <b>`num_threads`</b>: The number of threads enqueuing examples. * <b>`reader_num_threads`</b>: The number of threads to read examples.
* <b>`parser_num_threads`</b>: The number of threads to parse examples.
* <b>`name`</b>: Name of resulting op. * <b>`name`</b>: Name of resulting op.
##### Returns: ##### Returns:
@ -3475,7 +3484,7 @@ All ops are added to the default graph.
- - - - - -
### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features} ### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features}
Reads TFRecord, queues, batches and parses `Example` proto. Reads TFRecord, queues, batches and parses `Example` proto.
@ -3490,8 +3499,13 @@ See more detailed description in `read_examples`.
* <b>`features`</b>: A `dict` mapping feature keys to `FixedLenFeature` or * <b>`features`</b>: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values. `VarLenFeature` values.
* <b>`randomize_input`</b>: Whether the input should be randomized. * <b>`randomize_input`</b>: Whether the input should be randomized.
* <b>`num_epochs`</b>: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.initialize_all_variables() as shown in the tests.
* <b>`queue_capacity`</b>: Capacity for input queue. * <b>`queue_capacity`</b>: Capacity for input queue.
* <b>`num_threads`</b>: The number of threads enqueuing examples. * <b>`reader_num_threads`</b>: The number of threads to read examples.
* <b>`parser_num_threads`</b>: The number of threads to parse examples.
* <b>`name`</b>: Name of resulting op. * <b>`name`</b>: Name of resulting op.
##### Returns: ##### Returns:


@ -1,19 +1,19 @@
### `tf.contrib.distributions.gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#gaussian_conjugates_known_sigma_posterior} ### `tf.contrib.distributions.normal_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#normal_conjugates_known_sigma_posterior}
Posterior Gaussian distribution with conjugate prior on the mean. Posterior Normal distribution with conjugate prior on the mean.
This model assumes that `n` observations (with sum `s`) come from a This model assumes that `n` observations (with sum `s`) come from a
Gaussian with unknown mean `mu` (described by the Gaussian `prior`) Normal with unknown mean `mu` (described by the Normal `prior`)
and known variance `sigma^2`. The "known sigma posterior" is and known variance `sigma^2`. The "known sigma posterior" is
the distribution of the unknown `mu`. the distribution of the unknown `mu`.
Accepts a prior Gaussian distribution object, having parameters Accepts a prior Normal distribution object, having parameters
`mu0` and `sigma0`, as well as known `sigma` values of the predictive `mu0` and `sigma0`, as well as known `sigma` values of the predictive
distribution(s) (also assumed Gaussian), distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations). `n` (the number(s) of observations).
Returns a posterior (also Gaussian) distribution object, with parameters Returns a posterior (also Normal) distribution object, with parameters
`(mu', sigma'^2)`, where: `(mu', sigma'^2)`, where:
``` ```
@ -28,7 +28,7 @@ will broadcast in the case of multidimensional sets of parameters.
##### Args: ##### Args:
* <b>`prior`</b>: `Gaussian` object of type `dtype`: * <b>`prior`</b>: `Normal` object of type `dtype`:
the prior distribution having parameters `(mu0, sigma0)`. the prior distribution having parameters `(mu0, sigma0)`.
* <b>`sigma`</b>: tensor of type `dtype`, taking values `sigma > 0`. * <b>`sigma`</b>: tensor of type `dtype`, taking values `sigma > 0`.
The known stddev parameter(s). The known stddev parameter(s).
@ -37,12 +37,12 @@ will broadcast in the case of multidimensional sets of parameters.
##### Returns: ##### Returns:
A new Gaussian posterior distribution object for the unknown observation A new Normal posterior distribution object for the unknown observation
mean `mu`. mean `mu`.
##### Raises: ##### Raises:
* <b>`TypeError`</b>: if dtype of `s` does not match `dtype`, or `prior` is not a * <b>`TypeError`</b>: if dtype of `s` does not match `dtype`, or `prior` is not a
Gaussian object. Normal object.


@ -1,4 +1,4 @@
### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features} ### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features}
Reads TFRecord, queues, batches and parses `Example` proto. Reads TFRecord, queues, batches and parses `Example` proto.
@ -13,8 +13,13 @@ See more detailed description in `read_examples`.
* <b>`features`</b>: A `dict` mapping feature keys to `FixedLenFeature` or * <b>`features`</b>: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values. `VarLenFeature` values.
* <b>`randomize_input`</b>: Whether the input should be randomized. * <b>`randomize_input`</b>: Whether the input should be randomized.
* <b>`num_epochs`</b>: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.initialize_all_variables() as shown in the tests.
* <b>`queue_capacity`</b>: Capacity for input queue. * <b>`queue_capacity`</b>: Capacity for input queue.
* <b>`num_threads`</b>: The number of threads enqueuing examples. * <b>`reader_num_threads`</b>: The number of threads to read examples.
* <b>`parser_num_threads`</b>: The number of threads to parse examples.
* <b>`name`</b>: Name of resulting op. * <b>`name`</b>: Name of resulting op.
##### Returns: ##### Returns:
