Merge pull request #2518 from vrv/branch_123332988
Upstream changes from internal
commit 15e51e6113
@@ -53,6 +53,7 @@ def optimize_loss(loss,
clip_gradients=None,
moving_average_decay=0.9,
learning_rate_decay_fn=None,
update_ops=None,
variables=None,
name=None):
"""Given loss and parameters for optimizer, returns a training op.
@@ -81,6 +82,8 @@ def optimize_loss(loss,
Can be used to implement any learning rate decay
functions.
For example: tf.train.exponential_decay.
update_ops: list of update `Operation`s to execute at each step. If `None`,
uses elements of UPDATE_OPS collection.
variables: list of variables to optimize or
`None` to use all trainable variables.
name: The name for this operation is used to scope operations and summaries.
@@ -92,6 +95,15 @@ def optimize_loss(loss,
ValueError: if optimizer is wrong type.
"""
with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"):
# Update ops take UPDATE_OPS collection if not provided.
update_ops = (set(update_ops or []) or
set(ops.get_collection(ops.GraphKeys.UPDATE_OPS)))
# Make sure update ops are ran before computing loss.
if update_ops:
with ops.control_dependencies(update_ops):
barrier = control_flow_ops.no_op(name="update_barrier")
loss = control_flow_ops.with_dependencies([barrier], loss)

# Moving average of the loss with decay.
if moving_average_decay is not None:
# Generate moving averages of the loss.
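The `update_ops` handling added above relies on a small barrier pattern: the loss is routed through a no-op that depends on every update op, so the updates are guaranteed to run before the loss is evaluated. A standalone sketch of that pattern using only public `tf.*` equivalents of the internal calls in the hunk (names here are illustrative):

```python
import tensorflow as tf

def loss_after_updates(loss, update_ops):
  """Force `update_ops` to run before `loss` is evaluated (sketch)."""
  if not update_ops:
    return loss
  with tf.control_dependencies(list(update_ops)):
    # Anything that depends on this no-op can only run after every update op.
    barrier = tf.no_op(name="update_barrier")
  with tf.control_dependencies([barrier]):
    # tf.identity re-emits the loss value but inherits the dependency.
    return tf.identity(loss, name="loss_after_updates")
```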
@@ -132,6 +132,25 @@ class OptimizersTest(tf.test.TestCase):
tf.contrib.layers.optimize_loss(
loss, global_step, learning_rate=0.1, optimizer="SGD")

def testUpdateOp(self):
optimizers = ["SGD", tf.train.GradientDescentOptimizer,
tf.train.GradientDescentOptimizer(learning_rate=0.1)]
for optimizer in optimizers:
with tf.Graph().as_default() as g:
with self.test_session(graph=g) as session:
x, var, loss, global_step = _setup_model()
update_op = tf.assign(var, 20)
train = tf.contrib.layers.optimize_loss(loss,
global_step,
learning_rate=0.1,
optimizer=optimizer,
update_ops=[update_op])
tf.initialize_all_variables().run()
session.run(train, feed_dict={x: 5})
var_value, global_step_value = session.run([var, global_step])
# 19.5, due to update of var to 20 before loss computation.
self.assertEqual(var_value, 19.5)
self.assertEqual(global_step_value, 1)

if __name__ == "__main__":
tf.test.main()
@@ -195,7 +195,10 @@ def train(graph,
raise ValueError('No "global_step" was provided or found in the graph.')

# TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors.
if not monitors:
if not supervisor_is_chief:
# monitors should run only in supervisor.
monitors = []
elif not monitors:
monitors = monitors_lib.get_default_monitors(
loss_op=loss_op,
summary_op=logging_ops.get_summary_op(),
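The +/- markers were lost from this hunk, so the exact before/after is ambiguous; one plausible reading, stated as an assumption, is that default monitors are now created only for the chief supervisor while non-chief workers get an empty list. A hypothetical restatement of that control flow:

```python
def resolve_monitors(monitors, supervisor_is_chief, default_monitors_fn):
  # Hypothetical helper; only the shape of the conditional mirrors the hunk.
  if not supervisor_is_chief:
    # Monitors should run only in the (chief) supervisor process.
    return []
  if not monitors:
    return default_monitors_fn()
  return monitors
```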
@@ -26,8 +26,9 @@ from tensorflow.python.training import input as input_ops


def read_batch_examples(file_pattern, batch_size, reader,
randomize_input=True, queue_capacity=10000,
num_threads=1, name='dequeue_examples'):
randomize_input=True, num_epochs=None,
queue_capacity=10000, num_threads=1,
name=None):
"""Adds operations to read, queue, batch `Example` protos.

Given file pattern (or list of files), will setup a queue for file names,
@@ -46,6 +47,10 @@ def read_batch_examples(file_pattern, batch_size, reader,
reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If `None`, cycles through the dataset forever.
NOTE - If specified, creates a variable that must be initialized, so call
`tf.initialize_all_variables()` as shown in the tests.
queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples.
name: Name of resulting op.
@@ -82,39 +87,47 @@ def read_batch_examples(file_pattern, batch_size, reader,
(batch_size, queue_capacity))
if (not num_threads) or (num_threads <= 0):
raise ValueError('Invalid num_threads %s.' % num_threads)
if (num_epochs is not None) and (num_epochs <= 0):
raise ValueError('Invalid num_epochs %s.' % num_epochs)

with ops.name_scope(name) as scope:
with ops.op_scope([file_pattern], name, 'read_batch_examples') as scope:
# Setup filename queue with shuffling.
with ops.name_scope('file_name_queue') as file_name_queue_scope:
file_name_queue = input_ops.string_input_producer(
constant_op.constant(file_names, name='input'),
shuffle=randomize_input, name=file_name_queue_scope)
shuffle=randomize_input, num_epochs=num_epochs,
name=file_name_queue_scope)

# Create reader and set it to read from filename queue.
# Create readers, one per thread and set them to read from filename queue.
with ops.name_scope('read'):
_, example_proto = reader().read(file_name_queue)
example_list = []
for _ in range(num_threads):
_, example_proto = reader().read(file_name_queue)
example_list.append([example_proto])

# Setup batching queue.
# Setup batching queue given list of read example tensors.
if randomize_input:
if isinstance(batch_size, ops.Tensor):
min_after_dequeue = int(queue_capacity * 0.4)
else:
min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size)
examples = input_ops.shuffle_batch(
[example_proto], batch_size, capacity=queue_capacity,
num_threads=num_threads, min_after_dequeue=min_after_dequeue,
examples = input_ops.shuffle_batch_join(
example_list, batch_size, capacity=queue_capacity,
min_after_dequeue=min_after_dequeue,
name=scope)
else:
examples = input_ops.batch(
[example_proto], batch_size, capacity=queue_capacity,
num_threads=num_threads, name=scope)
examples = input_ops.batch_join(
example_list, batch_size, capacity=queue_capacity,
name=scope)

return examples


def read_batch_features(file_pattern, batch_size, features, reader,
randomize_input=True, queue_capacity=10000,
num_threads=1, name='dequeue_examples'):
randomize_input=True, num_epochs=None,
queue_capacity=10000, reader_num_threads=1,
parser_num_threads=1,
name=None):
"""Adds operations to read, queue, batch and parse `Example` protos.

Given file pattern (or list of files), will setup a queue for file names,
@@ -136,8 +149,13 @@ def read_batch_features(file_pattern, batch_size, features, reader,
reader: A function or class that returns an object with
`read` method, (filename tensor) -> (example tensor).
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.initialize_all_variables() as shown in the tests.
queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples.
reader_num_threads: The number of threads to read examples.
parser_num_threads: The number of threads to parse examples.
name: Name of resulting op.

Returns:
@@ -146,17 +164,29 @@ def read_batch_features(file_pattern, batch_size, features, reader,
Raises:
ValueError: for invalid inputs.
"""
examples = read_batch_examples(
file_pattern, batch_size, reader, randomize_input,
queue_capacity, num_threads, name=name)
with ops.op_scope([file_pattern], name, 'read_batch_features') as scope:
examples = read_batch_examples(
file_pattern, batch_size, reader, randomize_input=randomize_input,
num_epochs=num_epochs, queue_capacity=queue_capacity,
num_threads=reader_num_threads, name=scope)

# Parse features into tensors.
return parsing_ops.parse_example(examples, features)
# Parse features into tensors in many threads and put on the queue.
features_list = []
for _ in range(parser_num_threads):
features_list.append(parsing_ops.parse_example(examples, features))
return input_ops.batch_join(
features_list,
batch_size=batch_size,
capacity=queue_capacity,
enqueue_many=True,
name='parse_example_batch_join')


def read_batch_record_features(file_pattern, batch_size, features,
randomize_input=True, queue_capacity=10000,
num_threads=1, name='dequeue_record_examples'):
randomize_input=True, num_epochs=None,
queue_capacity=10000, reader_num_threads=1,
parser_num_threads=1,
name='dequeue_record_examples'):
"""Reads TFRecord, queues, batches and parses `Example` proto.

See more detailed description in `read_examples`.
@@ -168,8 +198,13 @@ def read_batch_record_features(file_pattern, batch_size, features,
features: A `dict` mapping feature keys to `FixedLenFeature` or
`VarLenFeature` values.
randomize_input: Whether the input should be randomized.
num_epochs: Integer specifying the number of times to read through the
dataset. If None, cycles through the dataset forever. NOTE - If specified,
creates a variable that must be initialized, so call
tf.initialize_all_variables() as shown in the tests.
queue_capacity: Capacity for input queue.
num_threads: The number of threads enqueuing examples.
reader_num_threads: The number of threads to read examples.
parser_num_threads: The number of threads to parse examples.
name: Name of resulting op.

Returns:
@@ -181,5 +216,6 @@ def read_batch_record_features(file_pattern, batch_size, features,
return read_batch_features(
file_pattern=file_pattern, batch_size=batch_size, features=features,
reader=io_ops.TFRecordReader,
randomize_input=randomize_input,
queue_capacity=queue_capacity, num_threads=num_threads, name=name)
randomize_input=randomize_input, num_epochs=num_epochs,
queue_capacity=queue_capacity, reader_num_threads=reader_num_threads,
parser_num_threads=parser_num_threads, name=name)
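The `num_epochs` behaviour documented above interacts with variable initialization in a way that is easy to miss. Below is a hedged usage sketch built only from the API surface this diff itself exercises (`tf.contrib.learn.io.read_batch_examples`, `tf.TFRecordReader`, `tf.initialize_all_variables`, queue runners); the file pattern is illustrative.

```python
import tensorflow as tf

# Hypothetical file pattern; any TFRecord files of serialized Examples work.
examples = tf.contrib.learn.io.read_batch_examples(
    "/tmp/data/*.tfrecord", batch_size=32, reader=tf.TFRecordReader,
    randomize_input=True, num_epochs=1, queue_capacity=10000,
    num_threads=2, name="train_input")

with tf.Session() as sess:
  # num_epochs creates an epoch-counting variable, so initialize first.
  sess.run(tf.initialize_all_variables())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess, coord=coord)
  try:
    while True:
      sess.run(examples)  # one batch of serialized Example protos
  except tf.errors.OutOfRangeError:
    pass  # raised once the single epoch is exhausted
  finally:
    coord.request_stop()
    coord.join(threads)
```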
@@ -17,10 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import tempfile

import tensorflow as tf

from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import gfile

@@ -55,44 +58,83 @@ class GraphIOTest(tf.test.TestCase):

self.assertRaisesRegexp(
ValueError, "No files match",
tf.contrib.learn.io.read_batch_features,
_INVALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
False, queue_capacity,
num_threads, name)
tf.contrib.learn.io.read_batch_examples,
_INVALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, num_epochs=None, queue_capacity=queue_capacity,
num_threads=num_threads, name=name)
self.assertRaisesRegexp(
ValueError, "Invalid batch_size",
tf.contrib.learn.io.read_batch_features,
_VALID_FILE_PATTERN, None, None, tf.TFRecordReader,
False, queue_capacity, num_threads, name)
tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, None, tf.TFRecordReader,
False, num_epochs=None, queue_capacity=queue_capacity,
num_threads=num_threads, name=name)
self.assertRaisesRegexp(
ValueError, "Invalid batch_size",
tf.contrib.learn.io.read_batch_features,
_VALID_FILE_PATTERN, -1, None, tf.TFRecordReader,
False, queue_capacity, num_threads, name)
tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, -1, tf.TFRecordReader,
False, num_epochs=None, queue_capacity=queue_capacity,
num_threads=num_threads, name=name)
self.assertRaisesRegexp(
ValueError, "Invalid queue_capacity",
tf.contrib.learn.io.read_batch_features,
_VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
False, None, num_threads, name)
tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, num_epochs=None, queue_capacity=None,
num_threads=num_threads, name=name)
self.assertRaisesRegexp(
ValueError, "Invalid num_threads",
tf.contrib.learn.io.read_batch_features,
_VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
False, queue_capacity, None,
name)
tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, num_epochs=None, queue_capacity=queue_capacity,
num_threads=None, name=name)
self.assertRaisesRegexp(
ValueError, "Invalid num_threads",
tf.contrib.learn.io.read_batch_features,
_VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader,
False, queue_capacity, -1,
name)
tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, num_epochs=None, queue_capacity=queue_capacity,
num_threads=-1, name=name)
self.assertRaisesRegexp(
ValueError, "Invalid batch_size",
tf.contrib.learn.io.read_batch_features,
_VALID_FILE_PATTERN, queue_capacity + 1, None, tf.TFRecordReader,
False, queue_capacity, 1, name)
tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, queue_capacity + 1, tf.TFRecordReader,
False, num_epochs=None, queue_capacity=queue_capacity,
num_threads=1, name=name)
self.assertRaisesRegexp(
ValueError, "Invalid num_epochs",
tf.contrib.learn.io.read_batch_examples,
_VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader,
False, num_epochs=-1, queue_capacity=queue_capacity, num_threads=1,
name=name)

def test_batch_tf_record(self):
def test_batch_record_features(self):
batch_size = 17
queue_capacity = 1234
name = "my_batch"
features = {"feature": tf.FixedLenFeature(shape=[0], dtype=tf.float32)}

with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
features = tf.contrib.learn.io.read_batch_record_features(
_VALID_FILE_PATTERN, batch_size, features, randomize_input=False,
queue_capacity=queue_capacity, reader_num_threads=2,
parser_num_threads=2, name=name)
self.assertEquals("%s/parse_example_batch_join:0" % name,
features["feature"].name)
file_name_queue_name = "%s/file_name_queue" % name
file_names_name = "%s/input" % file_name_queue_name
example_queue_name = "%s/fifo_queue" % name
parse_example_queue_name = "%s/parse_example_batch_join" % name
op_nodes = test_util.assert_ops_in_graph({
file_names_name: "Const",
file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "FIFOQueue",
parse_example_queue_name: "QueueDequeueMany",
name: "QueueDequeueMany"
}, g)
self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
self.assertEqual(
queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)

def test_one_epoch(self):
batch_size = 17
queue_capacity = 1234
name = "my_batch"
@@ -100,20 +142,25 @@ class GraphIOTest(tf.test.TestCase):
with tf.Graph().as_default() as g, self.test_session(graph=g) as sess:
inputs = tf.contrib.learn.io.read_batch_examples(
_VALID_FILE_PATTERN, batch_size,
reader=tf.TFRecordReader, randomize_input=False,
reader=tf.TFRecordReader, randomize_input=True,
num_epochs=1,
queue_capacity=queue_capacity, name=name)
self.assertEquals("%s:0" % name, inputs.name)
file_name_queue_name = "%s/file_name_queue" % name
file_name_queue_limit_name = (
"%s/limit_epochs/epochs" % file_name_queue_name)
file_names_name = "%s/input" % file_name_queue_name
example_queue_name = "%s/fifo_queue" % name
example_queue_name = "%s/random_shuffle_queue" % name
op_nodes = test_util.assert_ops_in_graph({
file_names_name: "Const",
file_name_queue_name: "FIFOQueue",
"%s/read/TFRecordReader" % name: "TFRecordReader",
example_queue_name: "FIFOQueue",
name: "QueueDequeueMany"
example_queue_name: "RandomShuffleQueue",
name: "QueueDequeueMany",
file_name_queue_limit_name: "Variable"
}, g)
self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0])
self.assertEqual(
set(_FILE_NAMES), set(sess.run(["%s:0" % file_names_name])[0]))
self.assertEqual(
queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)

@@ -143,6 +190,34 @@ class GraphIOTest(tf.test.TestCase):
self.assertEqual(
queue_capacity, op_nodes[example_queue_name].attr["capacity"].i)

def test_read_csv(self):
gfile.Glob = self._orig_glob
tempdir = tempfile.mkdtemp()
filename = os.path.join(tempdir, "file.csv")
gfile.Open(filename, "w").write("ABC\nDEF\nGHK\n")

batch_size = 1
queue_capacity = 5
name = "my_batch"

with tf.Graph().as_default() as g, self.test_session(graph=g) as session:
inputs = tf.contrib.learn.io.read_batch_examples(
filename, batch_size,
reader=tf.TextLineReader, randomize_input=False,
num_epochs=1, queue_capacity=queue_capacity, name=name)
session.run(tf.initialize_all_variables())

coord = tf.train.Coordinator()
tf.train.start_queue_runners(session, coord=coord)

self.assertAllEqual(session.run(inputs), [b"ABC"])
self.assertAllEqual(session.run(inputs), [b"DEF"])
self.assertAllEqual(session.run(inputs), [b"GHK"])
with self.assertRaises(errors.OutOfRangeError):
session.run(inputs)

coord.request_stop()


if __name__ == "__main__":
tf.test.main()
@@ -26,15 +26,17 @@ namespace tensorflow {
class SquaredLossUpdater : public DualLossUpdater {
public:
// Closed form solution that decreases the dual squared loss.
// See page 23 of http://arxiv.org/pdf/1309.2375v2.pdf
// See page 23 of http://arxiv.org/pdf/1309.2375v2.pdf for the derivation of
// the update rule when the example weights are equal to 1.0.
// Note: There is a typo in the formula in the paper: the denominator should
// be 1 + ||x_i||^2/(\lambda n) (without the 2 multiplier).
double ComputeUpdatedDual(const double label, const double example_weight,
const double current_dual, const double wx,
const double weighted_example_norm,
const double primal_loss_unused,
const double dual_loss_unused) const final {
const double delta_numerator = (label - current_dual - wx) * example_weight;
const double delta_denominator =
1 + weighted_example_norm * example_weight * example_weight * 0.5;
const double delta_numerator = label - current_dual - wx;
const double delta_denominator = 1 + weighted_example_norm * example_weight;
return current_dual + delta_numerator / delta_denominator;
}

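Reading the new `ComputeUpdatedDual` body directly, the squared-loss dual update it implements is the following closed form, where $y$, $\alpha$, $\langle w, x\rangle$, $s$ and $\bar{x}$ stand for the arguments `label`, `current_dual`, `wx`, `example_weight` and `weighted_example_norm`:

```latex
\alpha \leftarrow \alpha + \frac{y - \alpha - \langle w, x\rangle}{1 + s\,\bar{x}}
```

Per the comment above, this matches page 23 of the cited paper once the typo in the paper's denominator is corrected.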
@@ -455,6 +455,7 @@ class SdcaWithLogisticLossTest(SdcaOptimizerTest):
# TODO(katsiaspis): add a test for the case when examples at the end of an
# epoch are repeated, since example id may be duplicated.


class SdcaWithLinearLossTest(SdcaOptimizerTest):
"""SDCA optimizer test class for linear (squared) loss."""

@@ -488,9 +489,11 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest):
self.assertAllClose([-20.0 / 3.0, 28.0 / 3.0],
predictions.eval(),
rtol=0.005)
self.assertAllClose(0.01,
# Approximate gap should be very close to 0.0. (In fact, because the gap
# is only approximate, it is likely that upon convergence the duality gap
# can have a tiny negative value).
self.assertAllClose(0.00,
lr.approximate_duality_gap().eval(),
rtol=1e-2,
atol=1e-2)

def testL2Regularization(self):
@@ -580,7 +583,7 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest):
{'age': [1],
'gender': [1]}, 14.0, 2.0),
]
example_weights = [1.0, 1.0]
example_weights = [5.0, 3.0]
with self._single_threaded_test_session():
examples = make_example_dict(example_protos, example_weights)

@@ -597,20 +600,30 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest):
for _ in xrange(_MAX_ITERATIONS):
train_op.run()

# Predictions should be 8/9 of label due to minimizing regularized loss:
# (label - 2 * 2 * weight)^2 / 2 + L2 * 2 * weight^2
self.assertAllClose([-10.0 * 8 / 9, 14.0 * 8 / 9],
# There are 4 (sparse) variable weights to be learned. 2 for age and 2 for
# gender. Let w_1, w_2 be age weights, w_3, w_4 be gender weights, y_1,
# y_2 be the labels for examples 1 and 2 respectively and s_1, s_2 the
# corresponding *example* weights. With the given feature values, the loss
# function is given by:
# s_1/2(y_1 + 2w_1 + 2w_3)^2 + s_2/2(y_2 - 2w_2 - 2w_4)^2
# + \lambda/2 (w_1^2 + w_2^2 + w_3^2 + w_4^2). Solving for the optimal, it
# can be verified that:
# w_1* = w_3* = -2.0 s_1 y_1/(\lambda + 8 s_1) and
# w_2* = w_4* = 2 \cdot s_2 y_2/(\lambda + 8 s_2). Equivalently, due to
# regularization and example weights, the predictions are within:
# 8 \cdot s_i /(\lambda + 8 \cdot s_i) of the labels.
self.assertAllClose([-10 * 40.0 / 41.0, 14.0 * 24 / 25.0],
predictions.eval(),
rtol=0.07)
atol=0.01)

def testDenseFeatures(self):
def testDenseFeaturesWithDefaultWeights(self):
with self._single_threaded_test_session():
examples = make_dense_examples_dict(
dense_feature_values=[[-2.0, 0.0], [0.0, 2.0]],
dense_feature_values=[[1.0, 0.0], [0.0, 1.0]],
weights=[1.0, 1.0],
labels=[-10.0, 14.0])
labels=[10.0, -5.0])
variables = make_dense_variable_dict(2, 2)
options = dict(symmetric_l2_regularization=1,
options = dict(symmetric_l2_regularization=1.0,
symmetric_l1_regularization=0,
loss_type='squared_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
@@ -621,14 +634,51 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest):
for _ in xrange(_MAX_ITERATIONS):
train_op.run()

# Predictions should be 4/5 of label due to minimizing regularized loss:
# (label - 2 * weight)^2 / 2 + L2 * weight^2
self.assertAllClose([-10.0 * 4 / 5, 14.0 * 4 / 5],
# The loss function for these particular features is given by:
# 1/2(label_1-w_1)^2 + 1/2(label_2-w_2)^2 + \lambda/2 (w_1^2 + w_2^2). So,
# differentiating wrt to w_1, w_2 yields the following optimal values:
# w_1* = label_1/(\lambda + 1)= 10/2, w_2* =label_2/(\lambda + 1)= -5/2.
# In this case the (unnormalized regularized) loss will be:
# 1/2(10-5)^2 + 1/2(5-5/2)^2 + 1/2(5^2 + (5/2)^2) = 125.0/4. The actual
# loss should be further normalized by the sum of example weights.
self.assertAllClose([5.0, -2.5],
predictions.eval(),
rtol=0.01)

loss = lr.regularized_loss(examples)
self.assertAllClose(148.0 / 10.0, loss.eval(), atol=0.01)
self.assertAllClose(125.0 / 8.0, loss.eval(), atol=0.01)

def testDenseFeaturesWithArbitraryWeights(self):
with self._single_threaded_test_session():
examples = make_dense_examples_dict(
dense_feature_values=[[1.0, 0.0], [0.0, 1.0]],
weights=[20.0, 10.0],
labels=[10.0, -5.0])
variables = make_dense_variable_dict(2, 2)
options = dict(symmetric_l2_regularization=5.0,
symmetric_l1_regularization=0,
loss_type='squared_loss')
lr = SdcaModel(CONTAINER, examples, variables, options)
tf.initialize_all_variables().run()
predictions = lr.predictions(examples)

train_op = lr.minimize()
for _ in xrange(_MAX_ITERATIONS):
train_op.run()

# The loss function for these particular features is given by:
# 1/2 s_1 (label_1-w_1)^2 + 1/2 s_2(label_2-w_2)^2 +
# \lambda/2 (w_1^2 + w_2^2) where s_1, s_2 are the *example weights. It
# turns out that the optimal (variable) weights are given by:
# w_1* = label_1 \cdot s_1/(\lambda + s_1)= 8.0 and
# w_2* =label_2 \cdot s_2/(\lambda + s_2)= -10/3.
# In this case the (unnormalized regularized) loss will be:
# s_1/2(8-10)^2 + s_2/2(5-10/3)^2 + 5.0/2(8^2 + (10/3)^2) = 2175.0/9. The
# actual loss should be further normalized by the sum of example weights.
self.assertAllClose([8.0, -10.0/3],
predictions.eval(),
rtol=0.01)
loss = lr.regularized_loss(examples)
self.assertAllClose(2175.0 / 270.0, loss.eval(), atol=0.01)


class SdcaWithHingeLossTest(SdcaOptimizerTest):
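The expected values in `testDenseFeaturesWithDefaultWeights` follow from the closed form quoted in its comments; a quick arithmetic check of those numbers (plain Python, no TensorFlow):

```python
lam = 1.0
labels = [10.0, -5.0]
# Optimal weights from the comment: w_i* = label_i / (lam + 1).
w = [l / (lam + 1.0) for l in labels]                  # [5.0, -2.5]
data_term = sum(0.5 * (l - wi) ** 2 for l, wi in zip(labels, w))
reg_term = 0.5 * lam * sum(wi ** 2 for wi in w)
unnormalized = data_term + reg_term                    # 125.0 / 4
normalized = unnormalized / 2.0                        # sum of example weights is 2
assert abs(normalized - 125.0 / 8.0) < 1e-12
```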
@@ -19,7 +19,10 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.losses.python.losses.loss_ops import absolute_difference
from tensorflow.contrib.losses.python.losses.loss_ops import add_loss
from tensorflow.contrib.losses.python.losses.loss_ops import cosine_distance
from tensorflow.contrib.losses.python.losses.loss_ops import get_losses
from tensorflow.contrib.losses.python.losses.loss_ops import get_total_loss
from tensorflow.contrib.losses.python.losses.loss_ops import log
from tensorflow.contrib.losses.python.losses.loss_ops import sigmoid_cross_entropy
from tensorflow.contrib.losses.python.losses.loss_ops import softmax_cross_entropy
@@ -104,9 +104,11 @@ weighted average over the individual prediction errors:
weight = tf.div(weight, tf.size(weight))
loss = tf.contrib.losses.sum_of_squares(predictions, depths, weight)


@@absolute_difference
@@add_loss
@@cosine_distance
@@get_losses
@@get_total_loss
@@log
@@sigmoid_cross_entropy
@@softmax_cross_entropy
@@ -252,6 +254,61 @@ def _num_present(losses, weight, per_batch=False):
return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)


def add_loss(loss):
"""Adds a externally defined loss to collection of losses.

Args:
loss: A loss `Tensor`.
"""
ops.add_to_collection(ops.GraphKeys.LOSSES, loss)


def get_losses(scope=None):
"""Gets the list of loss variables.

Args:
scope: an optional scope for filtering the losses to return.

Returns:
a list of loss variables.
"""
return ops.get_collection(ops.GraphKeys.LOSSES, scope)


def get_regularization_losses(scope=None):
"""Gets the regularization losses.

Args:
scope: an optional scope for filtering the losses to return.

Returns:
A list of loss variables.
"""
return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)


def get_total_loss(add_regularization_losses=True, name="total_loss"):
"""Returns a tensor whose value represents the total loss.

Notice that the function adds the given losses to the regularization losses.

Args:
add_regularization_losses: A boolean indicating whether or not to use the
regularization losses in the sum.
name: The name of the returned tensor.

Returns:
A `Tensor` whose value represents the total loss.

Raises:
ValueError: if `losses` is not iterable.
"""
losses = get_losses()
if add_regularization_losses:
losses += get_regularization_losses()
return math_ops.add_n(losses, name=name)


def absolute_difference(predictions, targets, weight=1.0, scope=None):
"""Adds an Absolute Difference loss to the training procedure.

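A short usage sketch for the collection helpers added above (`add_loss`, `get_losses`, `get_total_loss`), assuming the `tf.contrib.losses` namespace that the `__init__.py` hunk earlier in this commit re-exports them into; the placeholder tensors are purely illustrative:

```python
import tensorflow as tf

predictions = tf.placeholder(tf.float32, [None, 10])
targets = tf.placeholder(tf.float32, [None, 10])

# Register two hand-rolled losses in the LOSSES collection.
tf.contrib.losses.add_loss(tf.reduce_mean(tf.square(predictions - targets)))
tf.contrib.losses.add_loss(tf.reduce_mean(tf.abs(predictions - targets)))

all_losses = tf.contrib.losses.get_losses()  # the two tensors above
total = tf.contrib.losses.get_total_loss()   # their sum plus any regularization losses
```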
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/simple_placer.h"

#include <memory>
#include <set>
#include <utility>
#include <vector>

@@ -182,6 +183,7 @@ class ColocationGraph {
Status ColocateNodes(const Node& x, const Node& y) {
int x_root = FindRoot(x.id());
int y_root = FindRoot(y.id());

Status s;
if (x_root != y_root) {
// Merge the sets by swinging the parent pointer of the smaller
@@ -229,6 +231,12 @@ class ColocationGraph {
s.error_message());
}

// Transfer ids in the old group to the new one.
members_[new_root].ids_in_group.insert(
members_[old_root].ids_in_group.begin(),
members_[old_root].ids_in_group.end());
members_[old_root].ids_in_group.clear();

// Ensure that the common root has at least one supported device
// type, by computing the intersection of
// members_[new_root].supported_device_types and
@@ -267,6 +275,9 @@ class ColocationGraph {
return Status::OK();
}

// String containing additional debugging info on failures.
string debug_info;

// We have not yet computed the possible devices for the
// colocated node set containing 'node', so we do so now using the
// constraints on the root node.
@@ -310,6 +321,8 @@ class ColocationGraph {
// Return an error when a physical device that matches an explicit
// device specification is not found. This ensures that we don't
// assign a node to GPU when the user wanted to force it on CPU.
AddDebugInfo(node_root, &debug_info);

DeviceNameUtils::ParsedName specified_device_name;
if (DeviceNameUtils::ParseFullName(node->def().device(),
&specified_device_name) &&
@@ -334,16 +347,17 @@ class ColocationGraph {
node->def().device(),
"' because no devices matching that specification "
"are registered in this process; available devices: ",
str_util::Join(device_names, ", "));
str_util::Join(device_names, ", "), debug_info);
} else if (specified_device_name.has_type) {
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
node->def().device(), "' because no supported kernel for ",
specified_device_name.type, " devices is available");
specified_device_name.type, " devices is available.",
debug_info);
} else {
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
node->def().device());
node->def().device(), debug_info);
}
} else {
// The specified device may be a valid device but the
@@ -355,7 +369,7 @@ class ColocationGraph {
"required incompatible device '",
DeviceNameUtils::ParsedNameToString(
members_[node_root].device_name),
"'");
"'", debug_info);
}
}
} else {
@@ -368,10 +382,11 @@ class ColocationGraph {
device_set_->devices(), members_[node_root].supported_device_types);

if (devices.empty()) {
AddDebugInfo(node_root, &debug_info);
return errors::InvalidArgument(
"Node had no OpKernel registered to support this operation: ",
"Operation was ", node->type_string(), " and inputs were ",
DataTypeVectorString(node->input_types()));
DataTypeVectorString(node->input_types()), debug_info);
}
}

@@ -390,6 +405,15 @@ class ColocationGraph {
// id if it is a root. parent <= 0 indicates that this member is invalid.
int parent = -1;

// The set of ids that are part of the disjoint node set forest.
//
// This is only fully specified in the root of a disjoint
// node set forest.
std::set<int> ids_in_group;

// The type of the op for this node.
string op_type;

// A proxy for the depth of the tree that is used to prefer
// connecting smaller trees to larger trees when merging disjoint
// sets.
@@ -410,8 +434,41 @@ class ColocationGraph {
std::vector<Device*> possible_devices;
};

// Adds debugging info to 'output' for the node referred to by
// 'node_root'.
void AddDebugInfo(const int node_root, string* output) {
if (members_[node_root].ids_in_group.size() > 1) {
strings::StrAppend(output, "\nColocation Debug Info:\n");

// If this node is part of a colocation group, then we want to
// collect the mapping of ops to supported devices, so that
// the user can see why an unsatisfiable placement occurred.
strings::StrAppend(
output, "Colocation group had the following types and devices: ");

std::unordered_map<string, string> type_to_devices;
for (const int id : members_[node_root].ids_in_group) {
const string& op_type = members_[id].op_type;
string devices_registered;
for (const auto& device_type : members_[id].supported_device_types) {
strings::StrAppend(&devices_registered, DeviceTypeString(device_type),
" ");
}

type_to_devices[op_type] = devices_registered;
}

for (const auto& td : type_to_devices) {
strings::StrAppend(output, "\n", td.first, ": ", td.second);
}
}
}

Status InitializeMember(const Node& node, Member* member) {
const int id = node.id();
member->ids_in_group.insert(id);
member->op_type = node.type_string();

if (id < 0) {
return errors::InvalidArgument("Node id was not positive: ", id);
}
@@ -729,6 +729,12 @@ TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) {
EXPECT_TRUE(StringPiece(s.error_message())
.contains("colocated with a group of nodes that required "
"incompatible device"));

// The error message should contain information that indicates which
// op types have which registered device types.
EXPECT_TRUE(StringPiece(s.error_message()).contains("VariableGPU: GPU")) << s;
EXPECT_TRUE(StringPiece(s.error_message()).contains("TestAssign: GPU CPU"))
<< s;
}

// Test that placement fails when an unknown device is requested.
@@ -13,75 +13,68 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/op.h"
#include "third_party/eigen3/Eigen/Core"

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"

#include "tensorflow/core/kernels/linalg_ops_common.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/binary_linalg_ops_common.h"

namespace tensorflow {

template <typename T>
class CholeskyGrad : public OpKernel {
template <typename Scalar, bool SupportsBatchOperationT>
class CholeskyGrad
: public BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT> {
public:
explicit CholeskyGrad(OpKernelConstruction* context) : OpKernel(context) {}
explicit CholeskyGrad(OpKernelConstruction* context)
: BinaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>(context) {}
~CholeskyGrad() override {}

using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using ConstMatrixMap = Eigen::Map<const Matrix>;
using MatrixMap = Eigen::Map<Matrix>;
using ConstRef = Eigen::Ref<const Matrix>;
using Ref = Eigen::Ref<Matrix>;

void Compute(OpKernelContext* context) override {
const Tensor& input_tensor_l = context->input(0);
const Tensor& input_tensor_grad = context->input(1);
// Check that input tensors represent a matrix.
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_l.shape()),
errors::InvalidArgument("In[0] is not a matrix"));
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_grad.shape()),
errors::InvalidArgument("In[1] is not a matrix"));
// Check that input tensors are square.
OP_REQUIRES(context,
input_tensor_l.dim_size(0) == input_tensor_l.dim_size(1),
errors::InvalidArgument("Input matrix must be square."));
OP_REQUIRES(context,
input_tensor_grad.dim_size(0) == input_tensor_grad.dim_size(1),
errors::InvalidArgument("Input matrix must be square."));
TensorShape GetOutputMatrixShape(
const TensorShape& input_matrix_l_full_shape,
const TensorShape& input_matrix_grad_shape) override {
return input_matrix_l_full_shape;
}

// Check that input tensors are of same size.
OP_REQUIRES(context,
input_tensor_l.dim_size(0) == input_tensor_grad.dim_size(0),
errors::InvalidArgument("Input matrices must be same size."));

// Create an output tensor
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(
0, input_tensor_grad.shape(), &output_tensor));

if (output_tensor->NumElements() == 0) {
// the output shape is a 0-element matrix, so there is nothing to do.
return;
int64 GetCostPerUnit(const TensorShape& input_matrix_shape,
const TensorShape& rhs_matrix_shape) override {
const int64 rows = input_matrix_shape.dim_size(0);
if (rows > (1LL << 20)) {
// A big number to cap the cost in case overflow.
return kint64max;
} else {
return rows * rows * rows;
}
// The next lines are necessary to get Eigen matrix behaviour.
const ConstMatrixMap input_matrix_l_full(input_tensor_l.flat<T>().data(),
input_tensor_l.dim_size(0),
input_tensor_l.dim_size(1));
const ConstMatrixMap input_matrix_grad(input_tensor_grad.flat<T>().data(),
input_tensor_grad.dim_size(0),
input_tensor_grad.dim_size(1));
MatrixMap output_matrix(output_tensor->template flat<T>().data(),
input_tensor_l.dim_size(0),
input_tensor_l.dim_size(1));
}

// Algorithm only depends on lower triangular half on input_tensor_l.
void ComputeMatrix(OpKernelContext* context,
const ConstMatrixMap& input_matrix_l_full,
const ConstMatrixMap& input_matrix_grad,
MatrixMap* output_matrix) override {
OP_REQUIRES(context,
input_matrix_l_full.rows() == input_matrix_l_full.cols(),
errors::InvalidArgument("Input matrix must be square."));
OP_REQUIRES(
context, input_matrix_l_full.cols() == input_matrix_grad.cols(),
errors::InvalidArgument(
"Input matrix and gradient must have same number of cols."));
OP_REQUIRES(
context, input_matrix_l_full.rows() == input_matrix_grad.rows(),
errors::InvalidArgument(
"Input matrix and gradient must have same number of rows."));

// Algorithm only depends on lower triangular half on input_matrix_l.
const Matrix input_matrix_l =
input_matrix_l_full.template triangularView<Eigen::Lower>();
// Algorithm only depends on lower triangular half on input_matrix_grad.
output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();
*output_matrix = input_matrix_grad.template triangularView<Eigen::Lower>();

const int64 kMatrixSize = input_matrix_l.rows();
const int64 kMaxBlockSize = 32;
@@ -104,20 +97,21 @@ class CholeskyGrad : public OpKernel {

auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin);
auto B_bar =
output_matrix.block(block_end, 0, trailing_size, block_begin);
output_matrix->block(block_end, 0, trailing_size, block_begin);

auto C = input_matrix_l.block(block_end, block_begin, trailing_size,
block_size);
auto C_bar = output_matrix.block(block_end, block_begin, trailing_size,
block_size);
auto C_bar = output_matrix->block(block_end, block_begin, trailing_size,
block_size);

auto D = input_matrix_l.block(block_begin, block_begin, block_size,
block_size);
auto D_bar =
output_matrix.block(block_begin, block_begin, block_size, block_size);
auto D_bar = output_matrix->block(block_begin, block_begin, block_size,
block_size);

auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin);
auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin);
auto R_bar =
output_matrix->block(block_begin, 0, block_size, block_begin);

C_bar = D.adjoint().template triangularView<Eigen::Upper>()
.solve(C_bar.adjoint()).adjoint();
@@ -127,9 +121,11 @@ class CholeskyGrad : public OpKernel {
CholeskyGradUnblocked(D, D_bar);
R_bar -= (D_bar + D_bar.adjoint()) * R;
}
output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval();
*output_matrix =
(0.5 * (*output_matrix + output_matrix->transpose())).eval();
}
void CholeskyGradUnblocked(const ConstRef l_block, Ref grad_block) {

void CholeskyGradUnblocked(const ConstRef& l_block, Ref grad_block) {
const int64 kMatrixSize = l_block.rows();
for (int64 k = kMatrixSize - 1; k >= 0; k--) {
/* This shows the block structure.
@@ -166,6 +162,11 @@ class CholeskyGrad : public OpKernel {
}
};

REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<float>), float);
REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad<double>), double);
REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<float, false>), float);
REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad<double, false>),
double);
REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<float, true>),
float);
REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad<double, true>),
double);
} // namespace tensorflow
@@ -64,8 +64,7 @@ class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
return perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>(
AsDeviceMemory<uint8>(nullptr, 0));
perftools::gputools::DeviceMemory<uint8>>();
}
// Hold the reference of the allocated tensors until the end of the
// allocator.

@@ -305,7 +305,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
const int out_offset =
(b * params.out_height + ph) * params.out_width + pw;
out_mat.col(out_offset) += in_mat.col(in_offset);
out_count(out_offset)++;
out_count(out_offset) += T(1);
}
}
}
@@ -3175,6 +3175,31 @@ op {
}
}
}
op {
name: "BatchCholeskyGrad"
input_arg {
name: "l"
type_attr: "T"
}
input_arg {
name: "grad"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
}
op {
name: "BatchFFT"
input_arg {
@@ -129,11 +129,34 @@ REGISTER_OP("CholeskyGrad")
.Doc(R"doc(
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.

For an explanation see "Differentiation of the Cholesky algorithm" by Iain Murray http://arxiv.org/abs/1602.07527.
For an explanation see "Differentiation of the Cholesky algorithm" by
Iain Murray http://arxiv.org/abs/1602.07527.

l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix.
grad: df/dl where f is some scalar function. Shape is `[M, M]'. Algorithm depends only on lower triangular part of this matrix.
output: Symmetrized version of df/dA . Shape is `[M, M]'
l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.
Algorithm depends only on lower triangular part of this matrix.
grad: df/dl where f is some scalar function. Shape is `[M, M]'.
Algorithm depends only on lower triangular part of this matrix.
output: Symmetrized version of df/dA . Shape is `[M, M]'.
)doc");

REGISTER_OP("BatchCholeskyGrad")
.Input("l: T")
.Input("grad: T")
.Output("output: T")
.Attr("T: {float, double}")
.Doc(R"doc(
Calculates the reverse mode backpropagated gradient of the Cholesky algorithm.

For an explanation see "Differentiation of the Cholesky algorithm" by
Iain Murray http://arxiv.org/abs/1602.07527.

l: Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.
Algorithm depends only on lower triangular part of the innermost matrices of
this tensor.
grad: df/dl where f is some scalar function. Shape is `[..., M, M]'.
Algorithm depends only on lower triangular part of the innermost matrices of
this tensor.
output: Symmetrized version of df/dA . Shape is `[..., M, M]'
)doc");

REGISTER_OP("SelfAdjointEig")
@@ -1397,6 +1397,36 @@ op {
summary: "Calculates the Cholesky decomposition of a batch of square matrices."
description: "The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions\nform square matrices, with the same constraints as the single matrix Cholesky\ndecomposition above. The output is a tensor of the same shape as the input\ncontaining the Cholesky decompositions for all input submatrices `[..., :, :]`."
}
op {
name: "BatchCholeskyGrad"
input_arg {
name: "l"
description: "Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.\nAlgorithm depends only on lower triangular part of the innermost matrices of\nthis tensor."
type_attr: "T"
}
input_arg {
name: "grad"
description: "df/dl where f is some scalar function. Shape is `[..., M, M]\'.\nAlgorithm depends only on lower triangular part of the innermost matrices of\nthis tensor."
type_attr: "T"
}
output_arg {
name: "output"
description: "Symmetrized version of df/dA . Shape is `[..., M, M]\'"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
summary: "Calculates the reverse mode backpropagated gradient of the Cholesky algorithm."
description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by\nIain Murray http://arxiv.org/abs/1602.07527."
}
op {
name: "BatchFFT"
input_arg {
@@ -2482,17 +2512,17 @@ op {
name: "CholeskyGrad"
input_arg {
name: "l"
description: "Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix."
description: "Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.\nAlgorithm depends only on lower triangular part of this matrix."
type_attr: "T"
}
input_arg {
name: "grad"
description: "df/dl where f is some scalar function. Shape is `[M, M]\'. Algorithm depends only on lower triangular part of this matrix."
description: "df/dl where f is some scalar function. Shape is `[M, M]\'.\nAlgorithm depends only on lower triangular part of this matrix."
type_attr: "T"
}
output_arg {
name: "output"
description: "Symmetrized version of df/dA . Shape is `[M, M]\'"
description: "Symmetrized version of df/dA . Shape is `[M, M]\'."
type_attr: "T"
}
attr {
@@ -2506,7 +2536,7 @@ op {
}
}
summary: "Calculates the reverse mode backpropagated gradient of the Cholesky algorithm."
description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by Iain Murray http://arxiv.org/abs/1602.07527."
description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by\nIain Murray http://arxiv.org/abs/1602.07527."
}
op {
name: "Complex"
@@ -11482,7 +11512,7 @@ op {
}
}
summary: "Computes the sum of elements across dimensions of a SparseTensor."
description: "This Op takes a SparseTensor and is the sparse counterpart to\n`tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor`\ninstead of a sparse one.\n\nReduces `sp_input` along the dimensions given in `reduction_axes`. Unless\n`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in\n`reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained\nwith length 1.\n\nIf `reduction_axes` has no entries, all dimensions are reduced, and a tensor\nwith a single element is returned."
description: "This Op takes a SparseTensor and is the sparse counterpart to\n`tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor`\ninstead of a sparse one.\n\nReduces `sp_input` along the dimensions given in `reduction_axes`. Unless\n`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in\n`reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained\nwith length 1.\n\nIf `reduction_axes` has no entries, all dimensions are reduced, and a tensor\nwith a single element is returned. Additionally, the axes can be negative,\nwhich are interpreted according to the indexing rules in Python."
}
op {
name: "SparseReorder"
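The amended `SparseReduceSum` description notes that negative `reduction_axes` follow Python indexing rules; a small sketch, assuming the contemporary Python wrapper `tf.sparse_reduce_sum` and the `shape`-based `tf.SparseTensor` constructor of that era:

```python
import tensorflow as tf

# A 2x3 SparseTensor with two non-zero entries.
sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 2.0], shape=[2, 3])

# axis -1 is the last dimension, i.e. the same as reduction_axes=1 here.
row_sums = tf.sparse_reduce_sum(sp, reduction_axes=-1)
```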
@@ -52,11 +52,11 @@ def train():
# Input placehoolders
with tf.name_scope('input'):
x = tf.placeholder(tf.float32, [None, 784], name='x-input')
y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')

with tf.name_scope('input_reshape'):
image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
tf.image_summary('input', image_shaped_input, 10)
y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
keep_prob = tf.placeholder(tf.float32)
tf.scalar_summary('dropout_keep_probability', keep_prob)

# We can't initialize these variables to 0 - the network will get stuck.
def weight_variable(shape):
@@ -105,7 +105,12 @@ def train():
return activations

hidden1 = nn_layer(x, 784, 500, 'layer1')
dropped = tf.nn.dropout(hidden1, keep_prob)

with tf.name_scope('dropout'):
keep_prob = tf.placeholder(tf.float32)
tf.scalar_summary('dropout_keep_probability', keep_prob)
dropped = tf.nn.dropout(hidden1, keep_prob)

y = nn_layer(dropped, 500, 10, 'layer2', act=tf.nn.softmax)

with tf.name_scope('cross_entropy'):
@@ -151,9 +156,20 @@ def train():
summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
test_writer.add_summary(summary, i)
print('Accuracy at step %s: %s' % (i, acc))
else: # Record train set summarieis, and train
summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
train_writer.add_summary(summary, i)
else: # Record train set summaries, and train
if i % 100 == 99: # Record execution stats
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, _ = sess.run([merged, train_step],
feed_dict=feed_dict(True),
options=run_options,
run_metadata=run_metadata)
train_writer.add_run_metadata(run_metadata, 'step%d' % i)
train_writer.add_summary(summary, i)
print('Adding run metadata for', i)
else: # Record a summary
summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
train_writer.add_summary(summary, i)


def main(_):
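Because the +/- markers were stripped from this hunk, the new every-100th-step branch is easier to read pulled out on its own; this sketch only restates it, reusing the script's own `sess`, `merged`, `train_step`, `train_writer`, `feed_dict` and loop index `i`:

```python
# Record execution stats for this step in addition to the usual summary.
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
summary, _ = sess.run([merged, train_step],
                      feed_dict=feed_dict(True),
                      options=run_options,
                      run_metadata=run_metadata)
train_writer.add_run_metadata(run_metadata, 'step%d' % i)
train_writer.add_summary(summary, i)
```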
@ -1338,9 +1338,9 @@ Variance of each batch member.
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.contrib.distributions.Gaussian` {#Gaussian}
|
||||
### `class tf.contrib.distributions.Normal` {#Normal}
|
||||
|
||||
The scalar Gaussian distribution with mean and stddev parameters mu, sigma.
|
||||
The scalar Normal distribution with mean and stddev parameters mu, sigma.
|
||||
|
||||
#### Mathematical details
|
||||
|
||||
@ -1353,15 +1353,15 @@ The PDF of this distribution is:
|
||||
Examples of initialization of one or a batch of distributions.
|
||||
|
||||
```python
|
||||
# Define a single scalar Gaussian distribution.
|
||||
dist = tf.contrib.distributions.Gaussian(mu=0, sigma=3)
|
||||
# Define a single scalar Normal distribution.
|
||||
dist = tf.contrib.distributions.Normal(mu=0, sigma=3)
|
||||
|
||||
# Evaluate the cdf at 1, returning a scalar.
|
||||
dist.cdf(1)
|
||||
|
||||
# Define a batch of two scalar valued Gaussians.
|
||||
# Define a batch of two scalar valued Normals.
|
||||
# The first has mean 1 and standard deviation 11, the second 2 and 22.
|
||||
dist = tf.contrib.distributions.Gaussian(mu=[1, 2.], sigma=[11, 22.])
|
||||
dist = tf.contrib.distributions.Normal(mu=[1, 2.], sigma=[11, 22.])
|
||||
|
||||
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
|
||||
# returning a length two tensor.
|
||||
@ -1374,9 +1374,9 @@ dist.sample(3)
|
||||
Arguments are broadcast when possible.
|
||||
|
||||
```python
|
||||
# Define a batch of two scalar valued Gaussians.
|
||||
# Define a batch of two scalar valued Normals.
|
||||
# Both have mean 1, but different standard deviations.
|
||||
dist = tf.contrib.distributions.Gaussian(mu=1, sigma=[11, 22.])
|
||||
dist = tf.contrib.distributions.Normal(mu=1, sigma=[11, 22.])
|
||||
|
||||
# Evaluate the pdf of both distributions on the same point, 3.0,
|
||||
# returning a length 2 tensor.
|
||||
@ -1384,9 +1384,9 @@ dist.pdf(3.0)
|
||||
```
|
||||
- - -
|
||||
|
||||
#### `tf.contrib.distributions.Gaussian.__init__(mu, sigma, name=None)` {#Gaussian.__init__}
#### `tf.contrib.distributions.Normal.__init__(mu, sigma, name=None)` {#Normal.__init__}

Construct Gaussian distributions with mean and stddev `mu` and `sigma`.
Construct Normal distributions with mean and stddev `mu` and `sigma`.

The parameters `mu` and `sigma` must be shaped in a way that supports
broadcasting (e.g. `mu + sigma` is a valid operation).
@ -1407,9 +1407,9 @@ broadcasting (e.g. `mu + sigma` is a valid operation).

- - -

#### `tf.contrib.distributions.Gaussian.cdf(x, name=None)` {#Gaussian.cdf}
#### `tf.contrib.distributions.Normal.cdf(x, name=None)` {#Normal.cdf}

CDF of observations in `x` under these Gaussian distribution(s).
CDF of observations in `x` under these Normal distribution(s).

##### Args:

@ -1425,16 +1425,16 @@ CDF of observations in `x` under these Gaussian distribution(s).

- - -

#### `tf.contrib.distributions.Gaussian.dtype` {#Gaussian.dtype}
#### `tf.contrib.distributions.Normal.dtype` {#Normal.dtype}

- - -

#### `tf.contrib.distributions.Gaussian.entropy(name=None)` {#Gaussian.entropy}
#### `tf.contrib.distributions.Normal.entropy(name=None)` {#Normal.entropy}

The entropy of Gaussian distribution(s).
The entropy of Normal distribution(s).

##### Args:

@ -1449,16 +1449,16 @@ The entropy of Gaussian distribution(s).

- - -

#### `tf.contrib.distributions.Gaussian.is_reparameterized` {#Gaussian.is_reparameterized}
#### `tf.contrib.distributions.Normal.is_reparameterized` {#Normal.is_reparameterized}

- - -
#### `tf.contrib.distributions.Gaussian.log_cdf(x, name=None)` {#Gaussian.log_cdf}
#### `tf.contrib.distributions.Normal.log_cdf(x, name=None)` {#Normal.log_cdf}

Log CDF of observations `x` under these Gaussian distribution(s).
Log CDF of observations `x` under these Normal distribution(s).

##### Args:

@ -1474,9 +1474,9 @@ Log CDF of observations `x` under these Gaussian distribution(s).

- - -

#### `tf.contrib.distributions.Gaussian.log_pdf(x, name=None)` {#Gaussian.log_pdf}
#### `tf.contrib.distributions.Normal.log_pdf(x, name=None)` {#Normal.log_pdf}

Log pdf of observations in `x` under these Gaussian distribution(s).
Log pdf of observations in `x` under these Normal distribution(s).

##### Args:

@ -1492,23 +1492,23 @@ Log pdf of observations in `x` under these Gaussian distribution(s).

- - -

#### `tf.contrib.distributions.Gaussian.mean` {#Gaussian.mean}
#### `tf.contrib.distributions.Normal.mean` {#Normal.mean}

- - -

#### `tf.contrib.distributions.Gaussian.mu` {#Gaussian.mu}
#### `tf.contrib.distributions.Normal.mu` {#Normal.mu}

- - -

#### `tf.contrib.distributions.Gaussian.pdf(x, name=None)` {#Gaussian.pdf}
#### `tf.contrib.distributions.Normal.pdf(x, name=None)` {#Normal.pdf}

The PDF of observations in `x` under these Gaussian distribution(s).
The PDF of observations in `x` under these Normal distribution(s).

##### Args:

@ -1524,9 +1524,9 @@ The PDF of observations in `x` under these Gaussian distribution(s).

- - -

#### `tf.contrib.distributions.Gaussian.sample(n, seed=None, name=None)` {#Gaussian.sample}
#### `tf.contrib.distributions.Normal.sample(n, seed=None, name=None)` {#Normal.sample}

Sample `n` observations from the Gaussian Distributions.
Sample `n` observations from the Normal Distributions.

##### Args:

@ -1544,7 +1544,7 @@ Sample `n` observations from the Gaussian Distributions.

- - -

#### `tf.contrib.distributions.Gaussian.sigma` {#Gaussian.sigma}
#### `tf.contrib.distributions.Normal.sigma` {#Normal.sigma}
@ -2443,26 +2443,26 @@ probability includes a combinatorial coefficient.

Functions that transform conjugate prior/likelihood pairs to distributions
representing the posterior or posterior predictive.

### Gaussian likelihood with conjugate prior.
### Normal likelihood with conjugate prior.

- - -

### `tf.contrib.distributions.gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#gaussian_conjugates_known_sigma_posterior}
### `tf.contrib.distributions.normal_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#normal_conjugates_known_sigma_posterior}

Posterior Gaussian distribution with conjugate prior on the mean.
Posterior Normal distribution with conjugate prior on the mean.

This model assumes that `n` observations (with sum `s`) come from a
Gaussian with unknown mean `mu` (described by the Gaussian `prior`)
Normal with unknown mean `mu` (described by the Normal `prior`)
and known variance `sigma^2`. The "known sigma posterior" is
the distribution of the unknown `mu`.

Accepts a prior Gaussian distribution object, having parameters
Accepts a prior Normal distribution object, having parameters
`mu0` and `sigma0`, as well as known `sigma` values of the predictive
distribution(s) (also assumed Gaussian),
distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations).

Returns a posterior (also Gaussian) distribution object, with parameters
Returns a posterior (also Normal) distribution object, with parameters
`(mu', sigma'^2)`, where:

```
@ -2477,7 +2477,7 @@ will broadcast in the case of multidimensional sets of parameters.

##### Args:

* <b>`prior`</b>: `Gaussian` object of type `dtype`:
* <b>`prior`</b>: `Normal` object of type `dtype`:
    the prior distribution having parameters `(mu0, sigma0)`.
* <b>`sigma`</b>: tensor of type `dtype`, taking values `sigma > 0`.
    The known stddev parameter(s).
@ -2486,35 +2486,35 @@ will broadcast in the case of multidimensional sets of parameters.

##### Returns:

  A new Gaussian posterior distribution object for the unknown observation
  A new Normal posterior distribution object for the unknown observation
  mean `mu`.

##### Raises:

* <b>`TypeError`</b>: if dtype of `s` does not match `dtype`, or `prior` is not a
    Gaussian object.
    Normal object.
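A short usage sketch consistent with the signature documented above; the prior parameters, observation sum, and observation count are made-up values:

```python
import tensorflow as tf

distributions = tf.contrib.distributions

# Prior belief about the unknown mean: mu ~ Normal(0, 1).
prior = distributions.Normal(mu=0., sigma=1.)

# Say n = 10 observations with known stddev sigma = 2.0 summed to s = 12.0.
# Float values keep all dtypes consistent with the prior.
posterior = distributions.normal_conjugates_known_sigma_posterior(
    prior=prior, sigma=2., s=12., n=10.)

with tf.Session() as sess:
  # The result is itself a Normal, so its mu and sigma tensors can be evaluated.
  print(sess.run([posterior.mu, posterior.sigma]))
```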
- - -

### `tf.contrib.distributions.gaussian_congugates_known_sigma_predictive(prior, sigma, s, n)` {#gaussian_congugates_known_sigma_predictive}
### `tf.contrib.distributions.normal_congugates_known_sigma_predictive(prior, sigma, s, n)` {#normal_congugates_known_sigma_predictive}

Posterior predictive Gaussian distribution w. conjugate prior on the mean.
Posterior predictive Normal distribution w. conjugate prior on the mean.

This model assumes that `n` observations (with sum `s`) come from a
Gaussian with unknown mean `mu` (described by the Gaussian `prior`)
Normal with unknown mean `mu` (described by the Normal `prior`)
and known variance `sigma^2`. The "known sigma predictive"
is the distribution of new observations, conditioned on the existing
observations and our prior.

Accepts a prior Gaussian distribution object, having parameters
Accepts a prior Normal distribution object, having parameters
`mu0` and `sigma0`, as well as known `sigma` values of the predictive
distribution(s) (also assumed Gaussian),
distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations).

Calculates the Gaussian distribution(s) `p(x | sigma^2)`:
Calculates the Normal distribution(s) `p(x | sigma^2)`:

```
p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu
@ -2536,7 +2536,7 @@ will broadcast in the case of multidimensional sets of parameters.

##### Args:

* <b>`prior`</b>: `Gaussian` object of type `dtype`:
* <b>`prior`</b>: `Normal` object of type `dtype`:
    the prior distribution having parameters `(mu0, sigma0)`.
* <b>`sigma`</b>: tensor of type `dtype`, taking values `sigma > 0`.
    The known stddev parameter(s).
@ -2545,12 +2545,12 @@ will broadcast in the case of multidimensional sets of parameters.

##### Returns:

  A new Gaussian predictive distribution object.
  A new Normal predictive distribution object.

##### Raises:

* <b>`TypeError`</b>: if dtype of `s` does not match `dtype`, or `prior` is not a
    Gaussian object.
    Normal object.
@ -339,7 +339,7 @@ Optimize weights given a loss.

- - -

### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, gradient_noise_scale=None, gradient_multipliers=None, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, variables=None, name=None)` {#optimize_loss}
### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, gradient_noise_scale=None, gradient_multipliers=None, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, update_ops=None, variables=None, name=None)` {#optimize_loss}

Given loss and parameters for optimizer, returns a training op.

@ -369,6 +369,8 @@ Given loss and parameters for optimizer, returns a training op.
    Can be used to implement any learning rate decay
    functions.
    For example: tf.train.exponential_decay.
* <b>`update_ops`</b>: list of update `Operation`s to execute at each step. If `None`,
    uses elements of UPDATE_OPS collection.
* <b>`variables`</b>: list of variables to optimize or
    `None` to use all trainable variables.
* <b>`name`</b>: The name for this operation is used to scope operations and summaries.
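A hedged sketch of a call that exercises the new `update_ops` argument, assuming the `"SGD"` string shorthand for the optimizer; the toy loss and the `tf.assign` op standing in for a collection-style update (e.g. a moving-average refresh) are illustrative, not part of the documented API:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [])
var = tf.Variable(2.0)
loss = tf.square(var * x - 10.0)
global_step = tf.Variable(0, trainable=False, name="global_step")

# An op that must run before the loss is computed at each step.
update_op = tf.assign(var, 2.5)

train_op = tf.contrib.layers.optimize_loss(
    loss, global_step, learning_rate=0.1, optimizer="SGD",
    update_ops=[update_op])

with tf.Session() as sess:
  sess.run(tf.initialize_all_variables())
  sess.run(train_op, feed_dict={x: 5.0})
```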
@ -3396,7 +3396,7 @@ Extracts numpy matrix from pandas DataFrame.

- - -

### `tf.contrib.learn.read_batch_examples(file_pattern, batch_size, reader, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_examples')` {#read_batch_examples}
### `tf.contrib.learn.read_batch_examples(file_pattern, batch_size, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, num_threads=1, name=None)` {#read_batch_examples}

Adds operations to read, queue, batch `Example` protos.

@ -3418,6 +3418,10 @@ All ops are added to the default graph.

* <b>`reader`</b>: A function or class that returns an object with
    `read` method, (filename tensor) -> (example tensor).
* <b>`randomize_input`</b>: Whether the input should be randomized.
* <b>`num_epochs`</b>: Integer specifying the number of times to read through the
    dataset. If `None`, cycles through the dataset forever.
    NOTE - If specified, creates a variable that must be initialized, so call
    `tf.initialize_all_variables()` as shown in the tests.
* <b>`queue_capacity`</b>: Capacity for input queue.
* <b>`num_threads`</b>: The number of threads enqueuing examples.
* <b>`name`</b>: Name of resulting op.
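A hedged sketch of a call that uses the new `num_epochs` argument; the file pattern and batch size are placeholders. Per the NOTE above, the epoch counter must be initialized, and the input pipeline needs queue runners:

```python
import tensorflow as tf

examples = tf.contrib.learn.read_batch_examples(
    file_pattern="/tmp/data/train-*.tfrecord",  # placeholder pattern
    batch_size=32,
    reader=tf.TFRecordReader,
    randomize_input=True,
    num_epochs=2)

with tf.Session() as sess:
  # num_epochs creates an epoch-counting variable, so initialize variables first.
  sess.run(tf.initialize_all_variables())
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  try:
    while not coord.should_stop():
      batch = sess.run(examples)  # a batch of serialized Example protos
  except tf.errors.OutOfRangeError:
    pass  # raised once the requested number of epochs is exhausted
  finally:
    coord.request_stop()
    coord.join(threads)
```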
@ -3434,7 +3438,7 @@ All ops are added to the default graph.

- - -

### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_examples')` {#read_batch_features}
### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name=None)` {#read_batch_features}

Adds operations to read, queue, batch and parse `Example` protos.

@ -3459,8 +3463,13 @@ All ops are added to the default graph.

* <b>`reader`</b>: A function or class that returns an object with
    `read` method, (filename tensor) -> (example tensor).
* <b>`randomize_input`</b>: Whether the input should be randomized.
* <b>`num_epochs`</b>: Integer specifying the number of times to read through the
    dataset. If None, cycles through the dataset forever. NOTE - If specified,
    creates a variable that must be initialized, so call
    tf.initialize_all_variables() as shown in the tests.
* <b>`queue_capacity`</b>: Capacity for input queue.
* <b>`num_threads`</b>: The number of threads enqueuing examples.
* <b>`reader_num_threads`</b>: The number of threads to read examples.
* <b>`parser_num_threads`</b>: The number of threads to parse examples.
* <b>`name`</b>: Name of resulting op.

##### Returns:

@ -3475,7 +3484,7 @@ All ops are added to the default graph.

- - -

### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features}
### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features}

Reads TFRecord, queues, batches and parses `Example` proto.

@ -3490,8 +3499,13 @@ See more detailed description in `read_examples`.

* <b>`features`</b>: A `dict` mapping feature keys to `FixedLenFeature` or
    `VarLenFeature` values.
* <b>`randomize_input`</b>: Whether the input should be randomized.
* <b>`num_epochs`</b>: Integer specifying the number of times to read through the
    dataset. If None, cycles through the dataset forever. NOTE - If specified,
    creates a variable that must be initialized, so call
    tf.initialize_all_variables() as shown in the tests.
* <b>`queue_capacity`</b>: Capacity for input queue.
* <b>`num_threads`</b>: The number of threads enqueuing examples.
* <b>`reader_num_threads`</b>: The number of threads to read examples.
* <b>`parser_num_threads`</b>: The number of threads to parse examples.
* <b>`name`</b>: Name of resulting op.

##### Returns:
@ -1,19 +1,19 @@

### `tf.contrib.distributions.gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#gaussian_conjugates_known_sigma_posterior}
### `tf.contrib.distributions.normal_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#normal_conjugates_known_sigma_posterior}

Posterior Gaussian distribution with conjugate prior on the mean.
Posterior Normal distribution with conjugate prior on the mean.

This model assumes that `n` observations (with sum `s`) come from a
Gaussian with unknown mean `mu` (described by the Gaussian `prior`)
Normal with unknown mean `mu` (described by the Normal `prior`)
and known variance `sigma^2`. The "known sigma posterior" is
the distribution of the unknown `mu`.

Accepts a prior Gaussian distribution object, having parameters
Accepts a prior Normal distribution object, having parameters
`mu0` and `sigma0`, as well as known `sigma` values of the predictive
distribution(s) (also assumed Gaussian),
distribution(s) (also assumed Normal),
and statistical estimates `s` (the sum(s) of the observations) and
`n` (the number(s) of observations).

Returns a posterior (also Gaussian) distribution object, with parameters
Returns a posterior (also Normal) distribution object, with parameters
`(mu', sigma'^2)`, where:

```
@ -28,7 +28,7 @@ will broadcast in the case of multidimensional sets of parameters.

##### Args:

* <b>`prior`</b>: `Gaussian` object of type `dtype`:
* <b>`prior`</b>: `Normal` object of type `dtype`:
    the prior distribution having parameters `(mu0, sigma0)`.
* <b>`sigma`</b>: tensor of type `dtype`, taking values `sigma > 0`.
    The known stddev parameter(s).
@ -37,12 +37,12 @@ will broadcast in the case of multidimensional sets of parameters.

##### Returns:

  A new Gaussian posterior distribution object for the unknown observation
  A new Normal posterior distribution object for the unknown observation
  mean `mu`.

##### Raises:

* <b>`TypeError`</b>: if dtype of `s` does not match `dtype`, or `prior` is not a
    Gaussian object.
    Normal object.
@ -1,4 +1,4 @@

### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features}
### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features}

Reads TFRecord, queues, batches and parses `Example` proto.

@ -13,8 +13,13 @@ See more detailed description in `read_examples`.

* <b>`features`</b>: A `dict` mapping feature keys to `FixedLenFeature` or
    `VarLenFeature` values.
* <b>`randomize_input`</b>: Whether the input should be randomized.
* <b>`num_epochs`</b>: Integer specifying the number of times to read through the
    dataset. If None, cycles through the dataset forever. NOTE - If specified,
    creates a variable that must be initialized, so call
    tf.initialize_all_variables() as shown in the tests.
* <b>`queue_capacity`</b>: Capacity for input queue.
* <b>`num_threads`</b>: The number of threads enqueuing examples.
* <b>`reader_num_threads`</b>: The number of threads to read examples.
* <b>`parser_num_threads`</b>: The number of threads to parse examples.
* <b>`name`</b>: Name of resulting op.

##### Returns: