diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index 1485cf958cb..3edd9e70105 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -53,6 +53,7 @@ def optimize_loss(loss, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, + update_ops=None, variables=None, name=None): """Given loss and parameters for optimizer, returns a training op. @@ -81,6 +82,8 @@ def optimize_loss(loss, Can be used to implement any learning rate decay functions. For example: tf.train.exponential_decay. + update_ops: list of update `Operation`s to execute at each step. If `None`, + uses elements of UPDATE_OPS collection. variables: list of variables to optimize or `None` to use all trainable variables. name: The name for this operation is used to scope operations and summaries. @@ -92,6 +95,15 @@ def optimize_loss(loss, ValueError: if optimizer is wrong type. """ with vs.variable_op_scope([loss, global_step], name, "OptimizeLoss"): + # Update ops take UPDATE_OPS collection if not provided. + update_ops = (set(update_ops or []) or + set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))) + # Make sure update ops are ran before computing loss. + if update_ops: + with ops.control_dependencies(update_ops): + barrier = control_flow_ops.no_op(name="update_barrier") + loss = control_flow_ops.with_dependencies([barrier], loss) + # Moving average of the loss with decay. if moving_average_decay is not None: # Generate moving averages of the loss. diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 0f0bfe568b8..49baffb6f95 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -132,6 +132,25 @@ class OptimizersTest(tf.test.TestCase): tf.contrib.layers.optimize_loss( loss, global_step, learning_rate=0.1, optimizer="SGD") + def testUpdateOp(self): + optimizers = ["SGD", tf.train.GradientDescentOptimizer, + tf.train.GradientDescentOptimizer(learning_rate=0.1)] + for optimizer in optimizers: + with tf.Graph().as_default() as g: + with self.test_session(graph=g) as session: + x, var, loss, global_step = _setup_model() + update_op = tf.assign(var, 20) + train = tf.contrib.layers.optimize_loss(loss, + global_step, + learning_rate=0.1, + optimizer=optimizer, + update_ops=[update_op]) + tf.initialize_all_variables().run() + session.run(train, feed_dict={x: 5}) + var_value, global_step_value = session.run([var, global_step]) + # 19.5, due to update of var to 20 before loss computation. + self.assertEqual(var_value, 19.5) + self.assertEqual(global_step_value, 1) if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 898590e52ff..1a69e291d1c 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -195,7 +195,10 @@ def train(graph, raise ValueError('No "global_step" was provided or found in the graph.') # TODO(ipolosukhin): Replace all functionality of Supervisor with Monitors. - if not monitors: + if not supervisor_is_chief: + # monitors should run only in supervisor. + monitors = [] + elif not monitors: monitors = monitors_lib.get_default_monitors( loss_op=loss_op, summary_op=logging_ops.get_summary_op(), diff --git a/tensorflow/contrib/learn/python/learn/io/graph_io.py b/tensorflow/contrib/learn/python/learn/io/graph_io.py index b9fffb2fb0c..bd1f4f3c0e6 100644 --- a/tensorflow/contrib/learn/python/learn/io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/io/graph_io.py @@ -26,8 +26,9 @@ from tensorflow.python.training import input as input_ops def read_batch_examples(file_pattern, batch_size, reader, - randomize_input=True, queue_capacity=10000, - num_threads=1, name='dequeue_examples'): + randomize_input=True, num_epochs=None, + queue_capacity=10000, num_threads=1, + name=None): """Adds operations to read, queue, batch `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -46,6 +47,10 @@ def read_batch_examples(file_pattern, batch_size, reader, reader: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). randomize_input: Whether the input should be randomized. + num_epochs: Integer specifying the number of times to read through the + dataset. If `None`, cycles through the dataset forever. + NOTE - If specified, creates a variable that must be initialized, so call + `tf.initialize_all_variables()` as shown in the tests. queue_capacity: Capacity for input queue. num_threads: The number of threads enqueuing examples. name: Name of resulting op. @@ -82,39 +87,47 @@ def read_batch_examples(file_pattern, batch_size, reader, (batch_size, queue_capacity)) if (not num_threads) or (num_threads <= 0): raise ValueError('Invalid num_threads %s.' % num_threads) + if (num_epochs is not None) and (num_epochs <= 0): + raise ValueError('Invalid num_epochs %s.' % num_epochs) - with ops.name_scope(name) as scope: + with ops.op_scope([file_pattern], name, 'read_batch_examples') as scope: # Setup filename queue with shuffling. with ops.name_scope('file_name_queue') as file_name_queue_scope: file_name_queue = input_ops.string_input_producer( constant_op.constant(file_names, name='input'), - shuffle=randomize_input, name=file_name_queue_scope) + shuffle=randomize_input, num_epochs=num_epochs, + name=file_name_queue_scope) - # Create reader and set it to read from filename queue. + # Create readers, one per thread and set them to read from filename queue. with ops.name_scope('read'): - _, example_proto = reader().read(file_name_queue) + example_list = [] + for _ in range(num_threads): + _, example_proto = reader().read(file_name_queue) + example_list.append([example_proto]) - # Setup batching queue. + # Setup batching queue given list of read example tensors. if randomize_input: if isinstance(batch_size, ops.Tensor): min_after_dequeue = int(queue_capacity * 0.4) else: min_after_dequeue = max(queue_capacity - (3 * batch_size), batch_size) - examples = input_ops.shuffle_batch( - [example_proto], batch_size, capacity=queue_capacity, - num_threads=num_threads, min_after_dequeue=min_after_dequeue, + examples = input_ops.shuffle_batch_join( + example_list, batch_size, capacity=queue_capacity, + min_after_dequeue=min_after_dequeue, name=scope) else: - examples = input_ops.batch( - [example_proto], batch_size, capacity=queue_capacity, - num_threads=num_threads, name=scope) + examples = input_ops.batch_join( + example_list, batch_size, capacity=queue_capacity, + name=scope) return examples def read_batch_features(file_pattern, batch_size, features, reader, - randomize_input=True, queue_capacity=10000, - num_threads=1, name='dequeue_examples'): + randomize_input=True, num_epochs=None, + queue_capacity=10000, reader_num_threads=1, + parser_num_threads=1, + name=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -136,8 +149,13 @@ def read_batch_features(file_pattern, batch_size, features, reader, reader: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). randomize_input: Whether the input should be randomized. + num_epochs: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. NOTE - If specified, + creates a variable that must be initialized, so call + tf.initialize_all_variables() as shown in the tests. queue_capacity: Capacity for input queue. - num_threads: The number of threads enqueuing examples. + reader_num_threads: The number of threads to read examples. + parser_num_threads: The number of threads to parse examples. name: Name of resulting op. Returns: @@ -146,17 +164,29 @@ def read_batch_features(file_pattern, batch_size, features, reader, Raises: ValueError: for invalid inputs. """ - examples = read_batch_examples( - file_pattern, batch_size, reader, randomize_input, - queue_capacity, num_threads, name=name) + with ops.op_scope([file_pattern], name, 'read_batch_features') as scope: + examples = read_batch_examples( + file_pattern, batch_size, reader, randomize_input=randomize_input, + num_epochs=num_epochs, queue_capacity=queue_capacity, + num_threads=reader_num_threads, name=scope) - # Parse features into tensors. - return parsing_ops.parse_example(examples, features) + # Parse features into tensors in many threads and put on the queue. + features_list = [] + for _ in range(parser_num_threads): + features_list.append(parsing_ops.parse_example(examples, features)) + return input_ops.batch_join( + features_list, + batch_size=batch_size, + capacity=queue_capacity, + enqueue_many=True, + name='parse_example_batch_join') def read_batch_record_features(file_pattern, batch_size, features, - randomize_input=True, queue_capacity=10000, - num_threads=1, name='dequeue_record_examples'): + randomize_input=True, num_epochs=None, + queue_capacity=10000, reader_num_threads=1, + parser_num_threads=1, + name='dequeue_record_examples'): """Reads TFRecord, queues, batches and parses `Example` proto. See more detailed description in `read_examples`. @@ -168,8 +198,13 @@ def read_batch_record_features(file_pattern, batch_size, features, features: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. randomize_input: Whether the input should be randomized. + num_epochs: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. NOTE - If specified, + creates a variable that must be initialized, so call + tf.initialize_all_variables() as shown in the tests. queue_capacity: Capacity for input queue. - num_threads: The number of threads enqueuing examples. + reader_num_threads: The number of threads to read examples. + parser_num_threads: The number of threads to parse examples. name: Name of resulting op. Returns: @@ -181,5 +216,6 @@ def read_batch_record_features(file_pattern, batch_size, features, return read_batch_features( file_pattern=file_pattern, batch_size=batch_size, features=features, reader=io_ops.TFRecordReader, - randomize_input=randomize_input, - queue_capacity=queue_capacity, num_threads=num_threads, name=name) + randomize_input=randomize_input, num_epochs=num_epochs, + queue_capacity=queue_capacity, reader_num_threads=reader_num_threads, + parser_num_threads=parser_num_threads, name=name) diff --git a/tensorflow/contrib/learn/python/learn/io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/io/graph_io_test.py index 175c29ac4ed..eaf62f003af 100644 --- a/tensorflow/contrib/learn/python/learn/io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/io/graph_io_test.py @@ -17,10 +17,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import random +import tempfile import tensorflow as tf +from tensorflow.python.framework import errors from tensorflow.python.framework import test_util from tensorflow.python.platform import gfile @@ -55,44 +58,83 @@ class GraphIOTest(tf.test.TestCase): self.assertRaisesRegexp( ValueError, "No files match", - tf.contrib.learn.io.read_batch_features, - _INVALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader, - False, queue_capacity, - num_threads, name) + tf.contrib.learn.io.read_batch_examples, + _INVALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader, + False, num_epochs=None, queue_capacity=queue_capacity, + num_threads=num_threads, name=name) self.assertRaisesRegexp( ValueError, "Invalid batch_size", - tf.contrib.learn.io.read_batch_features, - _VALID_FILE_PATTERN, None, None, tf.TFRecordReader, - False, queue_capacity, num_threads, name) + tf.contrib.learn.io.read_batch_examples, + _VALID_FILE_PATTERN, None, tf.TFRecordReader, + False, num_epochs=None, queue_capacity=queue_capacity, + num_threads=num_threads, name=name) self.assertRaisesRegexp( ValueError, "Invalid batch_size", - tf.contrib.learn.io.read_batch_features, - _VALID_FILE_PATTERN, -1, None, tf.TFRecordReader, - False, queue_capacity, num_threads, name) + tf.contrib.learn.io.read_batch_examples, + _VALID_FILE_PATTERN, -1, tf.TFRecordReader, + False, num_epochs=None, queue_capacity=queue_capacity, + num_threads=num_threads, name=name) self.assertRaisesRegexp( ValueError, "Invalid queue_capacity", - tf.contrib.learn.io.read_batch_features, - _VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader, - False, None, num_threads, name) + tf.contrib.learn.io.read_batch_examples, + _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader, + False, num_epochs=None, queue_capacity=None, + num_threads=num_threads, name=name) self.assertRaisesRegexp( ValueError, "Invalid num_threads", - tf.contrib.learn.io.read_batch_features, - _VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader, - False, queue_capacity, None, - name) + tf.contrib.learn.io.read_batch_examples, + _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader, + False, num_epochs=None, queue_capacity=queue_capacity, + num_threads=None, name=name) self.assertRaisesRegexp( ValueError, "Invalid num_threads", - tf.contrib.learn.io.read_batch_features, - _VALID_FILE_PATTERN, default_batch_size, None, tf.TFRecordReader, - False, queue_capacity, -1, - name) + tf.contrib.learn.io.read_batch_examples, + _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader, + False, num_epochs=None, queue_capacity=queue_capacity, + num_threads=-1, name=name) self.assertRaisesRegexp( ValueError, "Invalid batch_size", - tf.contrib.learn.io.read_batch_features, - _VALID_FILE_PATTERN, queue_capacity + 1, None, tf.TFRecordReader, - False, queue_capacity, 1, name) + tf.contrib.learn.io.read_batch_examples, + _VALID_FILE_PATTERN, queue_capacity + 1, tf.TFRecordReader, + False, num_epochs=None, queue_capacity=queue_capacity, + num_threads=1, name=name) + self.assertRaisesRegexp( + ValueError, "Invalid num_epochs", + tf.contrib.learn.io.read_batch_examples, + _VALID_FILE_PATTERN, default_batch_size, tf.TFRecordReader, + False, num_epochs=-1, queue_capacity=queue_capacity, num_threads=1, + name=name) - def test_batch_tf_record(self): + def test_batch_record_features(self): + batch_size = 17 + queue_capacity = 1234 + name = "my_batch" + features = {"feature": tf.FixedLenFeature(shape=[0], dtype=tf.float32)} + + with tf.Graph().as_default() as g, self.test_session(graph=g) as sess: + features = tf.contrib.learn.io.read_batch_record_features( + _VALID_FILE_PATTERN, batch_size, features, randomize_input=False, + queue_capacity=queue_capacity, reader_num_threads=2, + parser_num_threads=2, name=name) + self.assertEquals("%s/parse_example_batch_join:0" % name, + features["feature"].name) + file_name_queue_name = "%s/file_name_queue" % name + file_names_name = "%s/input" % file_name_queue_name + example_queue_name = "%s/fifo_queue" % name + parse_example_queue_name = "%s/parse_example_batch_join" % name + op_nodes = test_util.assert_ops_in_graph({ + file_names_name: "Const", + file_name_queue_name: "FIFOQueue", + "%s/read/TFRecordReader" % name: "TFRecordReader", + example_queue_name: "FIFOQueue", + parse_example_queue_name: "QueueDequeueMany", + name: "QueueDequeueMany" + }, g) + self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0]) + self.assertEqual( + queue_capacity, op_nodes[example_queue_name].attr["capacity"].i) + + def test_one_epoch(self): batch_size = 17 queue_capacity = 1234 name = "my_batch" @@ -100,20 +142,25 @@ class GraphIOTest(tf.test.TestCase): with tf.Graph().as_default() as g, self.test_session(graph=g) as sess: inputs = tf.contrib.learn.io.read_batch_examples( _VALID_FILE_PATTERN, batch_size, - reader=tf.TFRecordReader, randomize_input=False, + reader=tf.TFRecordReader, randomize_input=True, + num_epochs=1, queue_capacity=queue_capacity, name=name) self.assertEquals("%s:0" % name, inputs.name) file_name_queue_name = "%s/file_name_queue" % name + file_name_queue_limit_name = ( + "%s/limit_epochs/epochs" % file_name_queue_name) file_names_name = "%s/input" % file_name_queue_name - example_queue_name = "%s/fifo_queue" % name + example_queue_name = "%s/random_shuffle_queue" % name op_nodes = test_util.assert_ops_in_graph({ file_names_name: "Const", file_name_queue_name: "FIFOQueue", "%s/read/TFRecordReader" % name: "TFRecordReader", - example_queue_name: "FIFOQueue", - name: "QueueDequeueMany" + example_queue_name: "RandomShuffleQueue", + name: "QueueDequeueMany", + file_name_queue_limit_name: "Variable" }, g) - self.assertAllEqual(_FILE_NAMES, sess.run(["%s:0" % file_names_name])[0]) + self.assertEqual( + set(_FILE_NAMES), set(sess.run(["%s:0" % file_names_name])[0])) self.assertEqual( queue_capacity, op_nodes[example_queue_name].attr["capacity"].i) @@ -143,6 +190,34 @@ class GraphIOTest(tf.test.TestCase): self.assertEqual( queue_capacity, op_nodes[example_queue_name].attr["capacity"].i) + def test_read_csv(self): + gfile.Glob = self._orig_glob + tempdir = tempfile.mkdtemp() + filename = os.path.join(tempdir, "file.csv") + gfile.Open(filename, "w").write("ABC\nDEF\nGHK\n") + + batch_size = 1 + queue_capacity = 5 + name = "my_batch" + + with tf.Graph().as_default() as g, self.test_session(graph=g) as session: + inputs = tf.contrib.learn.io.read_batch_examples( + filename, batch_size, + reader=tf.TextLineReader, randomize_input=False, + num_epochs=1, queue_capacity=queue_capacity, name=name) + session.run(tf.initialize_all_variables()) + + coord = tf.train.Coordinator() + tf.train.start_queue_runners(session, coord=coord) + + self.assertAllEqual(session.run(inputs), [b"ABC"]) + self.assertAllEqual(session.run(inputs), [b"DEF"]) + self.assertAllEqual(session.run(inputs), [b"GHK"]) + with self.assertRaises(errors.OutOfRangeError): + session.run(inputs) + + coord.request_stop() + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h b/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h index fc37a98a4fc..da61e6dbfdb 100644 --- a/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h +++ b/tensorflow/contrib/linear_optimizer/kernels/squared-loss.h @@ -26,15 +26,17 @@ namespace tensorflow { class SquaredLossUpdater : public DualLossUpdater { public: // Closed form solution that decreases the dual squared loss. - // See page 23 of http://arxiv.org/pdf/1309.2375v2.pdf + // See page 23 of http://arxiv.org/pdf/1309.2375v2.pdf for the derivation of + // the update rule when the example weights are equal to 1.0. + // Note: There is a typo in the formula in the paper: the denominator should + // be 1 + ||x_i||^2/(\lambda n) (without the 2 multiplier). double ComputeUpdatedDual(const double label, const double example_weight, const double current_dual, const double wx, const double weighted_example_norm, const double primal_loss_unused, const double dual_loss_unused) const final { - const double delta_numerator = (label - current_dual - wx) * example_weight; - const double delta_denominator = - 1 + weighted_example_norm * example_weight * example_weight * 0.5; + const double delta_numerator = label - current_dual - wx; + const double delta_denominator = 1 + weighted_example_norm * example_weight; return current_dual + delta_numerator / delta_denominator; } diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py index e9739465ccb..211df8edd1e 100644 --- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py +++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py @@ -455,6 +455,7 @@ class SdcaWithLogisticLossTest(SdcaOptimizerTest): # TODO(katsiaspis): add a test for the case when examples at the end of an # epoch are repeated, since example id may be duplicated. + class SdcaWithLinearLossTest(SdcaOptimizerTest): """SDCA optimizer test class for linear (squared) loss.""" @@ -488,9 +489,11 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest): self.assertAllClose([-20.0 / 3.0, 28.0 / 3.0], predictions.eval(), rtol=0.005) - self.assertAllClose(0.01, + # Approximate gap should be very close to 0.0. (In fact, because the gap + # is only approximate, it is likely that upon convergence the duality gap + # can have a tiny negative value). + self.assertAllClose(0.00, lr.approximate_duality_gap().eval(), - rtol=1e-2, atol=1e-2) def testL2Regularization(self): @@ -580,7 +583,7 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest): {'age': [1], 'gender': [1]}, 14.0, 2.0), ] - example_weights = [1.0, 1.0] + example_weights = [5.0, 3.0] with self._single_threaded_test_session(): examples = make_example_dict(example_protos, example_weights) @@ -597,20 +600,30 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest): for _ in xrange(_MAX_ITERATIONS): train_op.run() - # Predictions should be 8/9 of label due to minimizing regularized loss: - # (label - 2 * 2 * weight)^2 / 2 + L2 * 2 * weight^2 - self.assertAllClose([-10.0 * 8 / 9, 14.0 * 8 / 9], + # There are 4 (sparse) variable weights to be learned. 2 for age and 2 for + # gender. Let w_1, w_2 be age weights, w_3, w_4 be gender weights, y_1, + # y_2 be the labels for examples 1 and 2 respectively and s_1, s_2 the + # corresponding *example* weights. With the given feature values, the loss + # function is given by: + # s_1/2(y_1 + 2w_1 + 2w_3)^2 + s_2/2(y_2 - 2w_2 - 2w_4)^2 + # + \lambda/2 (w_1^2 + w_2^2 + w_3^2 + w_4^2). Solving for the optimal, it + # can be verified that: + # w_1* = w_3* = -2.0 s_1 y_1/(\lambda + 8 s_1) and + # w_2* = w_4* = 2 \cdot s_2 y_2/(\lambda + 8 s_2). Equivalently, due to + # regularization and example weights, the predictions are within: + # 8 \cdot s_i /(\lambda + 8 \cdot s_i) of the labels. + self.assertAllClose([-10 * 40.0 / 41.0, 14.0 * 24 / 25.0], predictions.eval(), - rtol=0.07) + atol=0.01) - def testDenseFeatures(self): + def testDenseFeaturesWithDefaultWeights(self): with self._single_threaded_test_session(): examples = make_dense_examples_dict( - dense_feature_values=[[-2.0, 0.0], [0.0, 2.0]], + dense_feature_values=[[1.0, 0.0], [0.0, 1.0]], weights=[1.0, 1.0], - labels=[-10.0, 14.0]) + labels=[10.0, -5.0]) variables = make_dense_variable_dict(2, 2) - options = dict(symmetric_l2_regularization=1, + options = dict(symmetric_l2_regularization=1.0, symmetric_l1_regularization=0, loss_type='squared_loss') lr = SdcaModel(CONTAINER, examples, variables, options) @@ -621,14 +634,51 @@ class SdcaWithLinearLossTest(SdcaOptimizerTest): for _ in xrange(_MAX_ITERATIONS): train_op.run() - # Predictions should be 4/5 of label due to minimizing regularized loss: - # (label - 2 * weight)^2 / 2 + L2 * weight^2 - self.assertAllClose([-10.0 * 4 / 5, 14.0 * 4 / 5], + # The loss function for these particular features is given by: + # 1/2(label_1-w_1)^2 + 1/2(label_2-w_2)^2 + \lambda/2 (w_1^2 + w_2^2). So, + # differentiating wrt to w_1, w_2 yields the following optimal values: + # w_1* = label_1/(\lambda + 1)= 10/2, w_2* =label_2/(\lambda + 1)= -5/2. + # In this case the (unnormalized regularized) loss will be: + # 1/2(10-5)^2 + 1/2(5-5/2)^2 + 1/2(5^2 + (5/2)^2) = 125.0/4. The actual + # loss should be further normalized by the sum of example weights. + self.assertAllClose([5.0, -2.5], predictions.eval(), rtol=0.01) - loss = lr.regularized_loss(examples) - self.assertAllClose(148.0 / 10.0, loss.eval(), atol=0.01) + self.assertAllClose(125.0 / 8.0, loss.eval(), atol=0.01) + + def testDenseFeaturesWithArbitraryWeights(self): + with self._single_threaded_test_session(): + examples = make_dense_examples_dict( + dense_feature_values=[[1.0, 0.0], [0.0, 1.0]], + weights=[20.0, 10.0], + labels=[10.0, -5.0]) + variables = make_dense_variable_dict(2, 2) + options = dict(symmetric_l2_regularization=5.0, + symmetric_l1_regularization=0, + loss_type='squared_loss') + lr = SdcaModel(CONTAINER, examples, variables, options) + tf.initialize_all_variables().run() + predictions = lr.predictions(examples) + + train_op = lr.minimize() + for _ in xrange(_MAX_ITERATIONS): + train_op.run() + + # The loss function for these particular features is given by: + # 1/2 s_1 (label_1-w_1)^2 + 1/2 s_2(label_2-w_2)^2 + + # \lambda/2 (w_1^2 + w_2^2) where s_1, s_2 are the *example weights. It + # turns out that the optimal (variable) weights are given by: + # w_1* = label_1 \cdot s_1/(\lambda + s_1)= 8.0 and + # w_2* =label_2 \cdot s_2/(\lambda + s_2)= -10/3. + # In this case the (unnormalized regularized) loss will be: + # s_1/2(8-10)^2 + s_2/2(5-10/3)^2 + 5.0/2(8^2 + (10/3)^2) = 2175.0/9. The + # actual loss should be further normalized by the sum of example weights. + self.assertAllClose([8.0, -10.0/3], + predictions.eval(), + rtol=0.01) + loss = lr.regularized_loss(examples) + self.assertAllClose(2175.0 / 270.0, loss.eval(), atol=0.01) class SdcaWithHingeLossTest(SdcaOptimizerTest): diff --git a/tensorflow/contrib/losses/python/losses/__init__.py b/tensorflow/contrib/losses/python/losses/__init__.py index 0a549a6bc27..05bad60cabf 100644 --- a/tensorflow/contrib/losses/python/losses/__init__.py +++ b/tensorflow/contrib/losses/python/losses/__init__.py @@ -19,7 +19,10 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.losses.python.losses.loss_ops import absolute_difference +from tensorflow.contrib.losses.python.losses.loss_ops import add_loss from tensorflow.contrib.losses.python.losses.loss_ops import cosine_distance +from tensorflow.contrib.losses.python.losses.loss_ops import get_losses +from tensorflow.contrib.losses.python.losses.loss_ops import get_total_loss from tensorflow.contrib.losses.python.losses.loss_ops import log from tensorflow.contrib.losses.python.losses.loss_ops import sigmoid_cross_entropy from tensorflow.contrib.losses.python.losses.loss_ops import softmax_cross_entropy diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index c9947a4ec2b..2c22fb73d0e 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -104,9 +104,11 @@ weighted average over the individual prediction errors: weight = tf.div(weight, tf.size(weight)) loss = tf.contrib.losses.sum_of_squares(predictions, depths, weight) - @@absolute_difference +@@add_loss @@cosine_distance +@@get_losses +@@get_total_loss @@log @@sigmoid_cross_entropy @@softmax_cross_entropy @@ -252,6 +254,61 @@ def _num_present(losses, weight, per_batch=False): return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch) +def add_loss(loss): + """Adds a externally defined loss to collection of losses. + + Args: + loss: A loss `Tensor`. + """ + ops.add_to_collection(ops.GraphKeys.LOSSES, loss) + + +def get_losses(scope=None): + """Gets the list of loss variables. + + Args: + scope: an optional scope for filtering the losses to return. + + Returns: + a list of loss variables. + """ + return ops.get_collection(ops.GraphKeys.LOSSES, scope) + + +def get_regularization_losses(scope=None): + """Gets the regularization losses. + + Args: + scope: an optional scope for filtering the losses to return. + + Returns: + A list of loss variables. + """ + return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope) + + +def get_total_loss(add_regularization_losses=True, name="total_loss"): + """Returns a tensor whose value represents the total loss. + + Notice that the function adds the given losses to the regularization losses. + + Args: + add_regularization_losses: A boolean indicating whether or not to use the + regularization losses in the sum. + name: The name of the returned tensor. + + Returns: + A `Tensor` whose value represents the total loss. + + Raises: + ValueError: if `losses` is not iterable. + """ + losses = get_losses() + if add_regularization_losses: + losses += get_regularization_losses() + return math_ops.add_n(losses, name=name) + + def absolute_difference(predictions, targets, weight=1.0, scope=None): """Adds an Absolute Difference loss to the training procedure. diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc index c14c37a4c32..0c2ee20670f 100644 --- a/tensorflow/core/common_runtime/simple_placer.cc +++ b/tensorflow/core/common_runtime/simple_placer.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/simple_placer.h" #include +#include #include #include @@ -182,6 +183,7 @@ class ColocationGraph { Status ColocateNodes(const Node& x, const Node& y) { int x_root = FindRoot(x.id()); int y_root = FindRoot(y.id()); + Status s; if (x_root != y_root) { // Merge the sets by swinging the parent pointer of the smaller @@ -229,6 +231,12 @@ class ColocationGraph { s.error_message()); } + // Transfer ids in the old group to the new one. + members_[new_root].ids_in_group.insert( + members_[old_root].ids_in_group.begin(), + members_[old_root].ids_in_group.end()); + members_[old_root].ids_in_group.clear(); + // Ensure that the common root has at least one supported device // type, by computing the intersection of // members_[new_root].supported_device_types and @@ -267,6 +275,9 @@ class ColocationGraph { return Status::OK(); } + // String containing additional debugging info on failures. + string debug_info; + // We have not yet computed the possible devices for the // colocated node set containing 'node', so we do so now using the // constraints on the root node. @@ -310,6 +321,8 @@ class ColocationGraph { // Return an error when a physical device that matches an explicit // device specification is not found. This ensures that we don't // assign a node to GPU when the user wanted to force it on CPU. + AddDebugInfo(node_root, &debug_info); + DeviceNameUtils::ParsedName specified_device_name; if (DeviceNameUtils::ParseFullName(node->def().device(), &specified_device_name) && @@ -334,16 +347,17 @@ class ColocationGraph { node->def().device(), "' because no devices matching that specification " "are registered in this process; available devices: ", - str_util::Join(device_names, ", ")); + str_util::Join(device_names, ", "), debug_info); } else if (specified_device_name.has_type) { return errors::InvalidArgument( "Could not satisfy explicit device specification '", node->def().device(), "' because no supported kernel for ", - specified_device_name.type, " devices is available"); + specified_device_name.type, " devices is available.", + debug_info); } else { return errors::InvalidArgument( "Could not satisfy explicit device specification '", - node->def().device()); + node->def().device(), debug_info); } } else { // The specified device may be a valid device but the @@ -355,7 +369,7 @@ class ColocationGraph { "required incompatible device '", DeviceNameUtils::ParsedNameToString( members_[node_root].device_name), - "'"); + "'", debug_info); } } } else { @@ -368,10 +382,11 @@ class ColocationGraph { device_set_->devices(), members_[node_root].supported_device_types); if (devices.empty()) { + AddDebugInfo(node_root, &debug_info); return errors::InvalidArgument( "Node had no OpKernel registered to support this operation: ", "Operation was ", node->type_string(), " and inputs were ", - DataTypeVectorString(node->input_types())); + DataTypeVectorString(node->input_types()), debug_info); } } @@ -390,6 +405,15 @@ class ColocationGraph { // id if it is a root. parent <= 0 indicates that this member is invalid. int parent = -1; + // The set of ids that are part of the disjoint node set forest. + // + // This is only fully specified in the root of a disjoint + // node set forest. + std::set ids_in_group; + + // The type of the op for this node. + string op_type; + // A proxy for the depth of the tree that is used to prefer // connecting smaller trees to larger trees when merging disjoint // sets. @@ -410,8 +434,41 @@ class ColocationGraph { std::vector possible_devices; }; + // Adds debugging info to 'output' for the node referred to by + // 'node_root'. + void AddDebugInfo(const int node_root, string* output) { + if (members_[node_root].ids_in_group.size() > 1) { + strings::StrAppend(output, "\nColocation Debug Info:\n"); + + // If this node is part of a colocation group, then we want to + // collect the mapping of ops to supported devices, so that + // the user can see why an unsatisfiable placement occurred. + strings::StrAppend( + output, "Colocation group had the following types and devices: "); + + std::unordered_map type_to_devices; + for (const int id : members_[node_root].ids_in_group) { + const string& op_type = members_[id].op_type; + string devices_registered; + for (const auto& device_type : members_[id].supported_device_types) { + strings::StrAppend(&devices_registered, DeviceTypeString(device_type), + " "); + } + + type_to_devices[op_type] = devices_registered; + } + + for (const auto& td : type_to_devices) { + strings::StrAppend(output, "\n", td.first, ": ", td.second); + } + } + } + Status InitializeMember(const Node& node, Member* member) { const int id = node.id(); + member->ids_in_group.insert(id); + member->op_type = node.type_string(); + if (id < 0) { return errors::InvalidArgument("Node id was not positive: ", id); } diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc index cc726f376c4..1d1217bed44 100644 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ b/tensorflow/core/common_runtime/simple_placer_test.cc @@ -729,6 +729,12 @@ TEST_F(SimplePlacerTest, TestHeterogeneousDeviceSetFailure) { EXPECT_TRUE(StringPiece(s.error_message()) .contains("colocated with a group of nodes that required " "incompatible device")); + + // The error message should contain information that indicates which + // op types have which registered device types. + EXPECT_TRUE(StringPiece(s.error_message()).contains("VariableGPU: GPU")) << s; + EXPECT_TRUE(StringPiece(s.error_message()).contains("TestAssign: GPU CPU")) + << s; } // Test that placement fails when an unknown device is requested. diff --git a/tensorflow/core/kernels/cholesky_grad.cc b/tensorflow/core/kernels/cholesky_grad.cc index 4fefcee55e4..7a1c44da426 100644 --- a/tensorflow/core/kernels/cholesky_grad.cc +++ b/tensorflow/core/kernels/cholesky_grad.cc @@ -13,75 +13,68 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/op.h" #include "third_party/eigen3/Eigen/Core" - +#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" - -#include "tensorflow/core/kernels/linalg_ops_common.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/binary_linalg_ops_common.h" namespace tensorflow { -template -class CholeskyGrad : public OpKernel { +template +class CholeskyGrad + : public BinaryLinearAlgebraOp { public: - explicit CholeskyGrad(OpKernelConstruction* context) : OpKernel(context) {} + explicit CholeskyGrad(OpKernelConstruction* context) + : BinaryLinearAlgebraOp(context) {} + ~CholeskyGrad() override {} + using Matrix = - Eigen::Matrix; + Eigen::Matrix; using ConstMatrixMap = Eigen::Map; using MatrixMap = Eigen::Map; using ConstRef = Eigen::Ref; using Ref = Eigen::Ref; - void Compute(OpKernelContext* context) override { - const Tensor& input_tensor_l = context->input(0); - const Tensor& input_tensor_grad = context->input(1); - // Check that input tensors represent a matrix. - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_l.shape()), - errors::InvalidArgument("In[0] is not a matrix")); - OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_tensor_grad.shape()), - errors::InvalidArgument("In[1] is not a matrix")); - // Check that input tensors are square. - OP_REQUIRES(context, - input_tensor_l.dim_size(0) == input_tensor_l.dim_size(1), - errors::InvalidArgument("Input matrix must be square.")); - OP_REQUIRES(context, - input_tensor_grad.dim_size(0) == input_tensor_grad.dim_size(1), - errors::InvalidArgument("Input matrix must be square.")); + TensorShape GetOutputMatrixShape( + const TensorShape& input_matrix_l_full_shape, + const TensorShape& input_matrix_grad_shape) override { + return input_matrix_l_full_shape; + } - // Check that input tensors are of same size. - OP_REQUIRES(context, - input_tensor_l.dim_size(0) == input_tensor_grad.dim_size(0), - errors::InvalidArgument("Input matrices must be same size.")); - - // Create an output tensor - Tensor* output_tensor = NULL; - OP_REQUIRES_OK(context, context->allocate_output( - 0, input_tensor_grad.shape(), &output_tensor)); - - if (output_tensor->NumElements() == 0) { - // the output shape is a 0-element matrix, so there is nothing to do. - return; + int64 GetCostPerUnit(const TensorShape& input_matrix_shape, + const TensorShape& rhs_matrix_shape) override { + const int64 rows = input_matrix_shape.dim_size(0); + if (rows > (1LL << 20)) { + // A big number to cap the cost in case overflow. + return kint64max; + } else { + return rows * rows * rows; } - // The next lines are necessary to get Eigen matrix behaviour. - const ConstMatrixMap input_matrix_l_full(input_tensor_l.flat().data(), - input_tensor_l.dim_size(0), - input_tensor_l.dim_size(1)); - const ConstMatrixMap input_matrix_grad(input_tensor_grad.flat().data(), - input_tensor_grad.dim_size(0), - input_tensor_grad.dim_size(1)); - MatrixMap output_matrix(output_tensor->template flat().data(), - input_tensor_l.dim_size(0), - input_tensor_l.dim_size(1)); + } - // Algorithm only depends on lower triangular half on input_tensor_l. + void ComputeMatrix(OpKernelContext* context, + const ConstMatrixMap& input_matrix_l_full, + const ConstMatrixMap& input_matrix_grad, + MatrixMap* output_matrix) override { + OP_REQUIRES(context, + input_matrix_l_full.rows() == input_matrix_l_full.cols(), + errors::InvalidArgument("Input matrix must be square.")); + OP_REQUIRES( + context, input_matrix_l_full.cols() == input_matrix_grad.cols(), + errors::InvalidArgument( + "Input matrix and gradient must have same number of cols.")); + OP_REQUIRES( + context, input_matrix_l_full.rows() == input_matrix_grad.rows(), + errors::InvalidArgument( + "Input matrix and gradient must have same number of rows.")); + + // Algorithm only depends on lower triangular half on input_matrix_l. const Matrix input_matrix_l = input_matrix_l_full.template triangularView(); // Algorithm only depends on lower triangular half on input_matrix_grad. - output_matrix = input_matrix_grad.template triangularView(); + *output_matrix = input_matrix_grad.template triangularView(); const int64 kMatrixSize = input_matrix_l.rows(); const int64 kMaxBlockSize = 32; @@ -104,20 +97,21 @@ class CholeskyGrad : public OpKernel { auto B = input_matrix_l.block(block_end, 0, trailing_size, block_begin); auto B_bar = - output_matrix.block(block_end, 0, trailing_size, block_begin); + output_matrix->block(block_end, 0, trailing_size, block_begin); auto C = input_matrix_l.block(block_end, block_begin, trailing_size, block_size); - auto C_bar = output_matrix.block(block_end, block_begin, trailing_size, - block_size); + auto C_bar = output_matrix->block(block_end, block_begin, trailing_size, + block_size); auto D = input_matrix_l.block(block_begin, block_begin, block_size, block_size); - auto D_bar = - output_matrix.block(block_begin, block_begin, block_size, block_size); + auto D_bar = output_matrix->block(block_begin, block_begin, block_size, + block_size); auto R = input_matrix_l.block(block_begin, 0, block_size, block_begin); - auto R_bar = output_matrix.block(block_begin, 0, block_size, block_begin); + auto R_bar = + output_matrix->block(block_begin, 0, block_size, block_begin); C_bar = D.adjoint().template triangularView() .solve(C_bar.adjoint()).adjoint(); @@ -127,9 +121,11 @@ class CholeskyGrad : public OpKernel { CholeskyGradUnblocked(D, D_bar); R_bar -= (D_bar + D_bar.adjoint()) * R; } - output_matrix = (0.5 * (output_matrix + output_matrix.transpose())).eval(); + *output_matrix = + (0.5 * (*output_matrix + output_matrix->transpose())).eval(); } - void CholeskyGradUnblocked(const ConstRef l_block, Ref grad_block) { + + void CholeskyGradUnblocked(const ConstRef& l_block, Ref grad_block) { const int64 kMatrixSize = l_block.rows(); for (int64 k = kMatrixSize - 1; k >= 0; k--) { /* This shows the block structure. @@ -166,6 +162,11 @@ class CholeskyGrad : public OpKernel { } }; -REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad), float); -REGISTER_LINALG_OP("CholeskyGrad", (CholeskyGrad), double); +REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad), float); +REGISTER_BINARY_LINALG_OP("CholeskyGrad", (CholeskyGrad), + double); +REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad), + float); +REGISTER_BINARY_LINALG_OP("BatchCholeskyGrad", (CholeskyGrad), + double); } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index 14e7d033eb9..37c3711e4aa 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -64,8 +64,7 @@ class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator { AllocatorAttributes(), allocation_attr)); if (!allocation_status.ok()) { return perftools::gputools::port::StatusOr< - perftools::gputools::DeviceMemory>( - AsDeviceMemory(nullptr, 0)); + perftools::gputools::DeviceMemory>(); } // Hold the reference of the allocated tensors until the end of the // allocator. diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index 4bdeb6bf9c5..f38e8e751c5 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -305,7 +305,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output, const int out_offset = (b * params.out_height + ph) * params.out_width + pw; out_mat.col(out_offset) += in_mat.col(in_offset); - out_count(out_offset)++; + out_count(out_offset) += T(1); } } } diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt index 43d1b1002be..f44b7ea05e7 100644 --- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt @@ -3175,6 +3175,31 @@ op { } } } +op { + name: "BatchCholeskyGrad" + input_arg { + name: "l" + type_attr: "T" + } + input_arg { + name: "grad" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} op { name: "BatchFFT" input_arg { diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc index edd60df8ef9..be87022c0a8 100644 --- a/tensorflow/core/ops/linalg_ops.cc +++ b/tensorflow/core/ops/linalg_ops.cc @@ -129,11 +129,34 @@ REGISTER_OP("CholeskyGrad") .Doc(R"doc( Calculates the reverse mode backpropagated gradient of the Cholesky algorithm. -For an explanation see "Differentiation of the Cholesky algorithm" by Iain Murray http://arxiv.org/abs/1602.07527. +For an explanation see "Differentiation of the Cholesky algorithm" by +Iain Murray http://arxiv.org/abs/1602.07527. -l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix. -grad: df/dl where f is some scalar function. Shape is `[M, M]'. Algorithm depends only on lower triangular part of this matrix. -output: Symmetrized version of df/dA . Shape is `[M, M]' +l: Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. + Algorithm depends only on lower triangular part of this matrix. +grad: df/dl where f is some scalar function. Shape is `[M, M]'. + Algorithm depends only on lower triangular part of this matrix. +output: Symmetrized version of df/dA . Shape is `[M, M]'. +)doc"); + +REGISTER_OP("BatchCholeskyGrad") + .Input("l: T") + .Input("grad: T") + .Output("output: T") + .Attr("T: {float, double}") + .Doc(R"doc( +Calculates the reverse mode backpropagated gradient of the Cholesky algorithm. + +For an explanation see "Differentiation of the Cholesky algorithm" by +Iain Murray http://arxiv.org/abs/1602.07527. + +l: Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`. + Algorithm depends only on lower triangular part of the innermost matrices of + this tensor. +grad: df/dl where f is some scalar function. Shape is `[..., M, M]'. + Algorithm depends only on lower triangular part of the innermost matrices of + this tensor. +output: Symmetrized version of df/dA . Shape is `[..., M, M]' )doc"); REGISTER_OP("SelfAdjointEig") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 8cb7d4870be..2c0c6f30d3c 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -1397,6 +1397,36 @@ op { summary: "Calculates the Cholesky decomposition of a batch of square matrices." description: "The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions\nform square matrices, with the same constraints as the single matrix Cholesky\ndecomposition above. The output is a tensor of the same shape as the input\ncontaining the Cholesky decompositions for all input submatrices `[..., :, :]`." } +op { + name: "BatchCholeskyGrad" + input_arg { + name: "l" + description: "Output of batch Cholesky algorithm l = batch_cholesky(A). Shape is `[..., M, M]`.\nAlgorithm depends only on lower triangular part of the innermost matrices of\nthis tensor." + type_attr: "T" + } + input_arg { + name: "grad" + description: "df/dl where f is some scalar function. Shape is `[..., M, M]\'.\nAlgorithm depends only on lower triangular part of the innermost matrices of\nthis tensor." + type_attr: "T" + } + output_arg { + name: "output" + description: "Symmetrized version of df/dA . Shape is `[..., M, M]\'" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + summary: "Calculates the reverse mode backpropagated gradient of the Cholesky algorithm." + description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by\nIain Murray http://arxiv.org/abs/1602.07527." +} op { name: "BatchFFT" input_arg { @@ -2482,17 +2512,17 @@ op { name: "CholeskyGrad" input_arg { name: "l" - description: "Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`. Algorithm depends only on lower triangular part of this matrix." + description: "Output of Cholesky algorithm l = chol(A). Shape is `[M, M]`.\nAlgorithm depends only on lower triangular part of this matrix." type_attr: "T" } input_arg { name: "grad" - description: "df/dl where f is some scalar function. Shape is `[M, M]\'. Algorithm depends only on lower triangular part of this matrix." + description: "df/dl where f is some scalar function. Shape is `[M, M]\'.\nAlgorithm depends only on lower triangular part of this matrix." type_attr: "T" } output_arg { name: "output" - description: "Symmetrized version of df/dA . Shape is `[M, M]\'" + description: "Symmetrized version of df/dA . Shape is `[M, M]\'." type_attr: "T" } attr { @@ -2506,7 +2536,7 @@ op { } } summary: "Calculates the reverse mode backpropagated gradient of the Cholesky algorithm." - description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by Iain Murray http://arxiv.org/abs/1602.07527." + description: "For an explanation see \"Differentiation of the Cholesky algorithm\" by\nIain Murray http://arxiv.org/abs/1602.07527." } op { name: "Complex" @@ -11482,7 +11512,7 @@ op { } } summary: "Computes the sum of elements across dimensions of a SparseTensor." - description: "This Op takes a SparseTensor and is the sparse counterpart to\n`tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor`\ninstead of a sparse one.\n\nReduces `sp_input` along the dimensions given in `reduction_axes`. Unless\n`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in\n`reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained\nwith length 1.\n\nIf `reduction_axes` has no entries, all dimensions are reduced, and a tensor\nwith a single element is returned." + description: "This Op takes a SparseTensor and is the sparse counterpart to\n`tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor`\ninstead of a sparse one.\n\nReduces `sp_input` along the dimensions given in `reduction_axes`. Unless\n`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in\n`reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained\nwith length 1.\n\nIf `reduction_axes` has no entries, all dimensions are reduced, and a tensor\nwith a single element is returned. Additionally, the axes can be negative,\nwhich are interpreted according to the indexing rules in Python." } op { name: "SparseReorder" diff --git a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py index f2811b4dca9..a966e7fa493 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py +++ b/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py @@ -52,11 +52,11 @@ def train(): # Input placehoolders with tf.name_scope('input'): x = tf.placeholder(tf.float32, [None, 784], name='x-input') + y_ = tf.placeholder(tf.float32, [None, 10], name='y-input') + + with tf.name_scope('input_reshape'): image_shaped_input = tf.reshape(x, [-1, 28, 28, 1]) tf.image_summary('input', image_shaped_input, 10) - y_ = tf.placeholder(tf.float32, [None, 10], name='y-input') - keep_prob = tf.placeholder(tf.float32) - tf.scalar_summary('dropout_keep_probability', keep_prob) # We can't initialize these variables to 0 - the network will get stuck. def weight_variable(shape): @@ -105,7 +105,12 @@ def train(): return activations hidden1 = nn_layer(x, 784, 500, 'layer1') - dropped = tf.nn.dropout(hidden1, keep_prob) + + with tf.name_scope('dropout'): + keep_prob = tf.placeholder(tf.float32) + tf.scalar_summary('dropout_keep_probability', keep_prob) + dropped = tf.nn.dropout(hidden1, keep_prob) + y = nn_layer(dropped, 500, 10, 'layer2', act=tf.nn.softmax) with tf.name_scope('cross_entropy'): @@ -151,9 +156,20 @@ def train(): summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False)) test_writer.add_summary(summary, i) print('Accuracy at step %s: %s' % (i, acc)) - else: # Record train set summarieis, and train - summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True)) - train_writer.add_summary(summary, i) + else: # Record train set summaries, and train + if i % 100 == 99: # Record execution stats + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + run_metadata = tf.RunMetadata() + summary, _ = sess.run([merged, train_step], + feed_dict=feed_dict(True), + options=run_options, + run_metadata=run_metadata) + train_writer.add_run_metadata(run_metadata, 'step%d' % i) + train_writer.add_summary(summary, i) + print('Adding run metadata for', i) + else: # Record a summary + summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True)) + train_writer.add_summary(summary, i) def main(_): diff --git a/tensorflow/g3doc/api_docs/python/contrib.distributions.md b/tensorflow/g3doc/api_docs/python/contrib.distributions.md index 760335eafce..d76098b7002 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.distributions.md +++ b/tensorflow/g3doc/api_docs/python/contrib.distributions.md @@ -1338,9 +1338,9 @@ Variance of each batch member. - - - -### `class tf.contrib.distributions.Gaussian` {#Gaussian} +### `class tf.contrib.distributions.Normal` {#Normal} -The scalar Gaussian distribution with mean and stddev parameters mu, sigma. +The scalar Normal distribution with mean and stddev parameters mu, sigma. #### Mathematical details @@ -1353,15 +1353,15 @@ The PDF of this distribution is: Examples of initialization of one or a batch of distributions. ```python -# Define a single scalar Gaussian distribution. -dist = tf.contrib.distributions.Gaussian(mu=0, sigma=3) +# Define a single scalar Normal distribution. +dist = tf.contrib.distributions.Normal(mu=0, sigma=3) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1) -# Define a batch of two scalar valued Gaussians. +# Define a batch of two scalar valued Normals. # The first has mean 1 and standard deviation 11, the second 2 and 22. -dist = tf.contrib.distributions.Gaussian(mu=[1, 2.], sigma=[11, 22.]) +dist = tf.contrib.distributions.Normal(mu=[1, 2.], sigma=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -1374,9 +1374,9 @@ dist.sample(3) Arguments are broadcast when possible. ```python -# Define a batch of two scalar valued Gaussians. +# Define a batch of two scalar valued Normals. # Both have mean 1, but different standard deviations. -dist = tf.contrib.distributions.Gaussian(mu=1, sigma=[11, 22.]) +dist = tf.contrib.distributions.Normal(mu=1, sigma=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. @@ -1384,9 +1384,9 @@ dist.pdf(3.0) ``` - - - -#### `tf.contrib.distributions.Gaussian.__init__(mu, sigma, name=None)` {#Gaussian.__init__} +#### `tf.contrib.distributions.Normal.__init__(mu, sigma, name=None)` {#Normal.__init__} -Construct Gaussian distributions with mean and stddev `mu` and `sigma`. +Construct Normal distributions with mean and stddev `mu` and `sigma`. The parameters `mu` and `sigma` must be shaped in a way that supports broadcasting (e.g. `mu + sigma` is a valid operation). @@ -1407,9 +1407,9 @@ broadcasting (e.g. `mu + sigma` is a valid operation). - - - -#### `tf.contrib.distributions.Gaussian.cdf(x, name=None)` {#Gaussian.cdf} +#### `tf.contrib.distributions.Normal.cdf(x, name=None)` {#Normal.cdf} -CDF of observations in `x` under these Gaussian distribution(s). +CDF of observations in `x` under these Normal distribution(s). ##### Args: @@ -1425,16 +1425,16 @@ CDF of observations in `x` under these Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.dtype` {#Gaussian.dtype} +#### `tf.contrib.distributions.Normal.dtype` {#Normal.dtype} - - - -#### `tf.contrib.distributions.Gaussian.entropy(name=None)` {#Gaussian.entropy} +#### `tf.contrib.distributions.Normal.entropy(name=None)` {#Normal.entropy} -The entropy of Gaussian distribution(s). +The entropy of Normal distribution(s). ##### Args: @@ -1449,16 +1449,16 @@ The entropy of Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.is_reparameterized` {#Gaussian.is_reparameterized} +#### `tf.contrib.distributions.Normal.is_reparameterized` {#Normal.is_reparameterized} - - - -#### `tf.contrib.distributions.Gaussian.log_cdf(x, name=None)` {#Gaussian.log_cdf} +#### `tf.contrib.distributions.Normal.log_cdf(x, name=None)` {#Normal.log_cdf} -Log CDF of observations `x` under these Gaussian distribution(s). +Log CDF of observations `x` under these Normal distribution(s). ##### Args: @@ -1474,9 +1474,9 @@ Log CDF of observations `x` under these Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.log_pdf(x, name=None)` {#Gaussian.log_pdf} +#### `tf.contrib.distributions.Normal.log_pdf(x, name=None)` {#Normal.log_pdf} -Log pdf of observations in `x` under these Gaussian distribution(s). +Log pdf of observations in `x` under these Normal distribution(s). ##### Args: @@ -1492,23 +1492,23 @@ Log pdf of observations in `x` under these Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.mean` {#Gaussian.mean} +#### `tf.contrib.distributions.Normal.mean` {#Normal.mean} - - - -#### `tf.contrib.distributions.Gaussian.mu` {#Gaussian.mu} +#### `tf.contrib.distributions.Normal.mu` {#Normal.mu} - - - -#### `tf.contrib.distributions.Gaussian.pdf(x, name=None)` {#Gaussian.pdf} +#### `tf.contrib.distributions.Normal.pdf(x, name=None)` {#Normal.pdf} -The PDF of observations in `x` under these Gaussian distribution(s). +The PDF of observations in `x` under these Normal distribution(s). ##### Args: @@ -1524,9 +1524,9 @@ The PDF of observations in `x` under these Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.sample(n, seed=None, name=None)` {#Gaussian.sample} +#### `tf.contrib.distributions.Normal.sample(n, seed=None, name=None)` {#Normal.sample} -Sample `n` observations from the Gaussian Distributions. +Sample `n` observations from the Normal Distributions. ##### Args: @@ -1544,7 +1544,7 @@ Sample `n` observations from the Gaussian Distributions. - - - -#### `tf.contrib.distributions.Gaussian.sigma` {#Gaussian.sigma} +#### `tf.contrib.distributions.Normal.sigma` {#Normal.sigma} @@ -2443,26 +2443,26 @@ probability includes a combinatorial coefficient. Functions that transform conjugate prior/likelihood pairs to distributions representing the posterior or posterior predictive. -### Gaussian likelihood with conjugate prior. +### Normal likelihood with conjugate prior. - - - -### `tf.contrib.distributions.gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#gaussian_conjugates_known_sigma_posterior} +### `tf.contrib.distributions.normal_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#normal_conjugates_known_sigma_posterior} -Posterior Gaussian distribution with conjugate prior on the mean. +Posterior Normal distribution with conjugate prior on the mean. This model assumes that `n` observations (with sum `s`) come from a -Gaussian with unknown mean `mu` (described by the Gaussian `prior`) +Normal with unknown mean `mu` (described by the Normal `prior`) and known variance `sigma^2`. The "known sigma posterior" is the distribution of the unknown `mu`. -Accepts a prior Gaussian distribution object, having parameters +Accepts a prior Normal distribution object, having parameters `mu0` and `sigma0`, as well as known `sigma` values of the predictive -distribution(s) (also assumed Gaussian), +distribution(s) (also assumed Normal), and statistical estimates `s` (the sum(s) of the observations) and `n` (the number(s) of observations). -Returns a posterior (also Gaussian) distribution object, with parameters +Returns a posterior (also Normal) distribution object, with parameters `(mu', sigma'^2)`, where: ``` @@ -2477,7 +2477,7 @@ will broadcast in the case of multidimensional sets of parameters. ##### Args: -* `prior`: `Gaussian` object of type `dtype`: +* `prior`: `Normal` object of type `dtype`: the prior distribution having parameters `(mu0, sigma0)`. * `sigma`: tensor of type `dtype`, taking values `sigma > 0`. The known stddev parameter(s). @@ -2486,35 +2486,35 @@ will broadcast in the case of multidimensional sets of parameters. ##### Returns: - A new Gaussian posterior distribution object for the unknown observation + A new Normal posterior distribution object for the unknown observation mean `mu`. ##### Raises: * `TypeError`: if dtype of `s` does not match `dtype`, or `prior` is not a - Gaussian object. + Normal object. - - - -### `tf.contrib.distributions.gaussian_congugates_known_sigma_predictive(prior, sigma, s, n)` {#gaussian_congugates_known_sigma_predictive} +### `tf.contrib.distributions.normal_congugates_known_sigma_predictive(prior, sigma, s, n)` {#normal_congugates_known_sigma_predictive} -Posterior predictive Gaussian distribution w. conjugate prior on the mean. +Posterior predictive Normal distribution w. conjugate prior on the mean. This model assumes that `n` observations (with sum `s`) come from a -Gaussian with unknown mean `mu` (described by the Gaussian `prior`) +Normal with unknown mean `mu` (described by the Normal `prior`) and known variance `sigma^2`. The "known sigma predictive" is the distribution of new observations, conditioned on the existing observations and our prior. -Accepts a prior Gaussian distribution object, having parameters +Accepts a prior Normal distribution object, having parameters `mu0` and `sigma0`, as well as known `sigma` values of the predictive -distribution(s) (also assumed Gaussian), +distribution(s) (also assumed Normal), and statistical estimates `s` (the sum(s) of the observations) and `n` (the number(s) of observations). -Calculates the Gaussian distribution(s) `p(x | sigma^2)`: +Calculates the Normal distribution(s) `p(x | sigma^2)`: ``` p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu @@ -2536,7 +2536,7 @@ will broadcast in the case of multidimensional sets of parameters. ##### Args: -* `prior`: `Gaussian` object of type `dtype`: +* `prior`: `Normal` object of type `dtype`: the prior distribution having parameters `(mu0, sigma0)`. * `sigma`: tensor of type `dtype`, taking values `sigma > 0`. The known stddev parameter(s). @@ -2545,12 +2545,12 @@ will broadcast in the case of multidimensional sets of parameters. ##### Returns: - A new Gaussian predictive distribution object. + A new Normal predictive distribution object. ##### Raises: * `TypeError`: if dtype of `s` does not match `dtype`, or `prior` is not a - Gaussian object. + Normal object. diff --git a/tensorflow/g3doc/api_docs/python/contrib.layers.md b/tensorflow/g3doc/api_docs/python/contrib.layers.md index cb03110c9b0..8ba74c6faa1 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.layers.md +++ b/tensorflow/g3doc/api_docs/python/contrib.layers.md @@ -339,7 +339,7 @@ Optimize weights given a loss. - - - -### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, gradient_noise_scale=None, gradient_multipliers=None, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, variables=None, name=None)` {#optimize_loss} +### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, gradient_noise_scale=None, gradient_multipliers=None, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, update_ops=None, variables=None, name=None)` {#optimize_loss} Given loss and parameters for optimizer, returns a training op. @@ -369,6 +369,8 @@ Given loss and parameters for optimizer, returns a training op. Can be used to implement any learning rate decay functions. For example: tf.train.exponential_decay. +* `update_ops`: list of update `Operation`s to execute at each step. If `None`, + uses elements of UPDATE_OPS collection. * `variables`: list of variables to optimize or `None` to use all trainable variables. * `name`: The name for this operation is used to scope operations and summaries. diff --git a/tensorflow/g3doc/api_docs/python/contrib.learn.md b/tensorflow/g3doc/api_docs/python/contrib.learn.md index 70aff96a846..af57392deb9 100644 --- a/tensorflow/g3doc/api_docs/python/contrib.learn.md +++ b/tensorflow/g3doc/api_docs/python/contrib.learn.md @@ -3396,7 +3396,7 @@ Extracts numpy matrix from pandas DataFrame. - - - -### `tf.contrib.learn.read_batch_examples(file_pattern, batch_size, reader, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_examples')` {#read_batch_examples} +### `tf.contrib.learn.read_batch_examples(file_pattern, batch_size, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, num_threads=1, name=None)` {#read_batch_examples} Adds operations to read, queue, batch `Example` protos. @@ -3418,6 +3418,10 @@ All ops are added to the default graph. * `reader`: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). * `randomize_input`: Whether the input should be randomized. +* `num_epochs`: Integer specifying the number of times to read through the + dataset. If `None`, cycles through the dataset forever. + NOTE - If specified, creates a variable that must be initialized, so call + `tf.initialize_all_variables()` as shown in the tests. * `queue_capacity`: Capacity for input queue. * `num_threads`: The number of threads enqueuing examples. * `name`: Name of resulting op. @@ -3434,7 +3438,7 @@ All ops are added to the default graph. - - - -### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_examples')` {#read_batch_features} +### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name=None)` {#read_batch_features} Adds operations to read, queue, batch and parse `Example` protos. @@ -3459,8 +3463,13 @@ All ops are added to the default graph. * `reader`: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). * `randomize_input`: Whether the input should be randomized. +* `num_epochs`: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. NOTE - If specified, + creates a variable that must be initialized, so call + tf.initialize_all_variables() as shown in the tests. * `queue_capacity`: Capacity for input queue. -* `num_threads`: The number of threads enqueuing examples. +* `reader_num_threads`: The number of threads to read examples. +* `parser_num_threads`: The number of threads to parse examples. * `name`: Name of resulting op. ##### Returns: @@ -3475,7 +3484,7 @@ All ops are added to the default graph. - - - -### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features} +### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features} Reads TFRecord, queues, batches and parses `Example` proto. @@ -3490,8 +3499,13 @@ See more detailed description in `read_examples`. * `features`: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. * `randomize_input`: Whether the input should be randomized. +* `num_epochs`: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. NOTE - If specified, + creates a variable that must be initialized, so call + tf.initialize_all_variables() as shown in the tests. * `queue_capacity`: Capacity for input queue. -* `num_threads`: The number of threads enqueuing examples. +* `reader_num_threads`: The number of threads to read examples. +* `parser_num_threads`: The number of threads to parse examples. * `name`: Name of resulting op. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.OpError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.OpError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.OpError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.OpError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ReaderBase.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.ReaderBase.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ReaderBase.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.ReaderBase.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.TFRecordReader.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.TFRecordReader.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.TFRecordReader.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.TFRecordReader.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Variable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.Variable.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Variable.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.Variable.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.argmin.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.argmin.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.argmin.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.argmin.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_less_equal.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.assert_less_equal.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_less_equal.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.assert_less_equal.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_rank.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.assert_rank.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_rank.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.assert_rank.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_fft.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.batch_fft.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_fft.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.batch_fft.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_band_part.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.batch_matrix_band_part.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_band_part.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.batch_matrix_band_part.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.complex_abs.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.complex_abs.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.complex_abs.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.complex_abs.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.copy_graph.copy_op_to_graph.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.copy_graph.copy_op_to_graph.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.copy_graph.copy_op_to_graph.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.copy_graph.copy_op_to_graph.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Exponential.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.Exponential.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Exponential.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.Exponential.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.MultivariateNormal.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.MultivariateNormal.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.MultivariateNormal.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.MultivariateNormal.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.gaussian_conjugates_known_sigma_posterior.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.normal_conjugates_known_sigma_posterior.md similarity index 64% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.gaussian_conjugates_known_sigma_posterior.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.normal_conjugates_known_sigma_posterior.md index f2ff765acb6..ae8eb008903 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.gaussian_conjugates_known_sigma_posterior.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.distributions.normal_conjugates_known_sigma_posterior.md @@ -1,19 +1,19 @@ -### `tf.contrib.distributions.gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#gaussian_conjugates_known_sigma_posterior} +### `tf.contrib.distributions.normal_conjugates_known_sigma_posterior(prior, sigma, s, n)` {#normal_conjugates_known_sigma_posterior} -Posterior Gaussian distribution with conjugate prior on the mean. +Posterior Normal distribution with conjugate prior on the mean. This model assumes that `n` observations (with sum `s`) come from a -Gaussian with unknown mean `mu` (described by the Gaussian `prior`) +Normal with unknown mean `mu` (described by the Normal `prior`) and known variance `sigma^2`. The "known sigma posterior" is the distribution of the unknown `mu`. -Accepts a prior Gaussian distribution object, having parameters +Accepts a prior Normal distribution object, having parameters `mu0` and `sigma0`, as well as known `sigma` values of the predictive -distribution(s) (also assumed Gaussian), +distribution(s) (also assumed Normal), and statistical estimates `s` (the sum(s) of the observations) and `n` (the number(s) of observations). -Returns a posterior (also Gaussian) distribution object, with parameters +Returns a posterior (also Normal) distribution object, with parameters `(mu', sigma'^2)`, where: ``` @@ -28,7 +28,7 @@ will broadcast in the case of multidimensional sets of parameters. ##### Args: -* `prior`: `Gaussian` object of type `dtype`: +* `prior`: `Normal` object of type `dtype`: the prior distribution having parameters `(mu0, sigma0)`. * `sigma`: tensor of type `dtype`, taking values `sigma > 0`. The known stddev parameter(s). @@ -37,12 +37,12 @@ will broadcast in the case of multidimensional sets of parameters. ##### Returns: - A new Gaussian posterior distribution object for the unknown observation + A new Normal posterior distribution object for the unknown observation mean `mu`. ##### Raises: * `TypeError`: if dtype of `s` does not match `dtype`, or `prior` is not a - Gaussian object. + Normal object. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.xavier_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.layers.xavier_initializer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.xavier_initializer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.layers.xavier_initializer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowLinearRegressor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.TensorFlowLinearRegressor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowLinearRegressor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.TensorFlowLinearRegressor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_dask_data.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.extract_dask_data.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_dask_data.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.extract_dask_data.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.read_batch_record_features.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.read_batch_record_features.md similarity index 57% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.read_batch_record_features.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.read_batch_record_features.md index f7aeb003daa..aa4e964be14 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.read_batch_record_features.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.read_batch_record_features.md @@ -1,4 +1,4 @@ -### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features} +### `tf.contrib.learn.read_batch_record_features(file_pattern, batch_size, features, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name='dequeue_record_examples')` {#read_batch_record_features} Reads TFRecord, queues, batches and parses `Example` proto. @@ -13,8 +13,13 @@ See more detailed description in `read_examples`. * `features`: A `dict` mapping feature keys to `FixedLenFeature` or `VarLenFeature` values. * `randomize_input`: Whether the input should be randomized. +* `num_epochs`: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. NOTE - If specified, + creates a variable that must be initialized, so call + tf.initialize_all_variables() as shown in the tests. * `queue_capacity`: Capacity for input queue. -* `num_threads`: The number of threads enqueuing examples. +* `reader_num_threads`: The number of threads to read examples. +* `parser_num_threads`: The number of threads to parse examples. * `name`: Name of resulting op. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.train.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.train.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.train.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.learn.train.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.set_difference.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.metrics.set_difference.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.set_difference.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.metrics.set_difference.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean_absolute_error.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.metrics.streaming_mean_absolute_error.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean_absolute_error.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.metrics.streaming_mean_absolute_error.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean_cosine_distance.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.metrics.streaming_mean_cosine_distance.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean_cosine_distance.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.contrib.metrics.streaming_mean_cosine_distance.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.diag_part.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag_part.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.diag_part.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.erf.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.erf.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.erf.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.erf.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.greater_equal.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.greater_equal.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.greater_equal.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.greater_equal.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.group.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.group.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.group.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.group.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.adjust_contrast.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.image.adjust_contrast.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.adjust_contrast.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.image.adjust_contrast.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_contrast.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.image.random_contrast.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_contrast.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.image.random_contrast.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.sample_distorted_bounding_box.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.image.sample_distorted_bounding_box.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.sample_distorted_bounding_box.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.image.sample_distorted_bounding_box.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.import_graph_def.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.import_graph_def.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.import_graph_def.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.import_graph_def.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.initialize_local_variables.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.initialize_local_variables.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.initialize_local_variables.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.initialize_local_variables.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.initialize_variables.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.initialize_variables.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.initialize_variables.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.initialize_variables.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.inv.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.inv.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.inv.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.inv.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.local_variables.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.local_variables.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.local_variables.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.local_variables.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.logical_xor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.logical_xor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.logical_xor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.logical_xor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_solve_ls.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.matrix_solve_ls.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_solve_ls.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.matrix_solve_ls.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.avg_pool.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.avg_pool.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.avg_pool.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.avg_pool.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.batch_normalization.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.batch_normalization.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.batch_normalization.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.batch_normalization.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.in_top_k.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.in_top_k.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.in_top_k.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.in_top_k.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.softsign.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.softsign.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.softsign.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.softsign.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sufficient_statistics.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.sufficient_statistics.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sufficient_statistics.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.nn.sufficient_statistics.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ones_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.ones_initializer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ones_initializer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.ones_initializer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.placeholder.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.placeholder.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.placeholder.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.placeholder.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.polygamma.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.polygamma.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.polygamma.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.polygamma.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.range.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.range.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.range.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.range.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.register_tensor_conversion_function.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.register_tensor_conversion_function.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.register_tensor_conversion_function.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.register_tensor_conversion_function.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reverse_sequence.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.reverse_sequence.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reverse_sequence.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.reverse_sequence.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_sum.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.segment_sum.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_sum.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.segment_sum.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.self_adjoint_eig.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.self_adjoint_eig.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.self_adjoint_eig.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.self_adjoint_eig.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_add.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.sparse_add.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_add.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.sparse_add.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sqrt.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.sqrt.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sqrt.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.sqrt.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.stop_gradient.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.stop_gradient.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.stop_gradient.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.stop_gradient.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.compute_gradient_error.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.test.compute_gradient_error.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.compute_gradient_error.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.test.compute_gradient_error.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_int32.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.to_int32.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_int32.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.to_int32.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.trace.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.trace.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.trace.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.trace.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Optimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.Optimizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Optimizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.Optimizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Saver.from_proto.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.Saver.from_proto.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Saver.from_proto.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.Saver.from_proto.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Saver.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.Saver.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Saver.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.Saver.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.string_input_producer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.string_input_producer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.string_input_producer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.train.string_input_producer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.truediv.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.truediv.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.truediv.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.truediv.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.while_loop.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.while_loop.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.while_loop.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard0/tf.while_loop.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.DType.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.DType.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.DType.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.DType.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Dimension.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Dimension.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Dimension.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Dimension.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.RegisterGradient.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.RegisterGradient.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.RegisterGradient.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.RegisterGradient.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Tensor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.Tensor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.abs.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.abs.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.abs.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.abs.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.add_n.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.add_n.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.add_n.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.add_n.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_variables_initialized.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.assert_variables_initialized.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_variables_initialized.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.assert_variables_initialized.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_self_adjoint_eig.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.batch_self_adjoint_eig.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_self_adjoint_eig.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.batch_self_adjoint_eig.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.DirichletMultinomial.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.DirichletMultinomial.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.DirichletMultinomial.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.distributions.DirichletMultinomial.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.apply_regularization.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.apply_regularization.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.apply_regularization.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.apply_regularization.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.summarize_tensor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_tensor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.layers.summarize_tensor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.RunConfig.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.RunConfig.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.RunConfig.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.learn.RunConfig.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean_relative_error.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.metrics.streaming_mean_relative_error.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean_relative_error.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.contrib.metrics.streaming_mean_relative_error.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.control_dependencies.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.control_dependencies.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.control_dependencies.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.control_dependencies.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.diag.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.diag.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.diag.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.dynamic_partition.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.dynamic_partition.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.dynamic_partition.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.dynamic_partition.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.ResourceExhaustedError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.errors.ResourceExhaustedError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.ResourceExhaustedError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.errors.ResourceExhaustedError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.histogram_fixed_width.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.histogram_fixed_width.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.histogram_fixed_width.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.histogram_fixed_width.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ifft.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.ifft.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ifft.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.ifft.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.adjust_brightness.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.adjust_brightness.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.adjust_brightness.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.adjust_brightness.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.central_crop.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.central_crop.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.central_crop.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.central_crop.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.draw_bounding_boxes.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.draw_bounding_boxes.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.draw_bounding_boxes.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.draw_bounding_boxes.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.encode_jpeg.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.encode_jpeg.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.encode_jpeg.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.encode_jpeg.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.encode_png.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.encode_png.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.encode_png.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.encode_png.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.transpose_image.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.transpose_image.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.transpose_image.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.image.transpose_image.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.invert_permutation.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.invert_permutation.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.invert_permutation.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.invert_permutation.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.log.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.log.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.log.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.log.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.logical_or.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.logical_or.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.logical_or.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.logical_or.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_determinant.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.matrix_determinant.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_determinant.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.matrix_determinant.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.minimum.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.minimum.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.minimum.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.minimum.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.atrous_conv2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.atrous_conv2d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.atrous_conv2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.atrous_conv2d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.conv2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.conv2d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.conv2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.conv2d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.embedding_lookup.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.embedding_lookup.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.embedding_lookup.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.embedding_lookup.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.log_softmax.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.log_softmax.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.log_softmax.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.log_softmax.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.separable_conv2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.separable_conv2d.md similarity index 91% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.separable_conv2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.separable_conv2d.md index f4be03303fc..a88c2112efe 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.separable_conv2d.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.separable_conv2d.md @@ -38,3 +38,9 @@ horizontal and vertical strides, `strides = [1, stride, stride, 1]`. A 4-D `Tensor` of shape `[batch, out_height, out_width, out_channels]`. +##### Raises: + + +* `ValueError`: If channel_multiplier * in_channels > out_channels, + which means that the separable convolution is overparameterized. + diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sigmoid_cross_entropy_with_logits.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.sigmoid_cross_entropy_with_logits.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sigmoid_cross_entropy_with_logits.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.sigmoid_cross_entropy_with_logits.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.uniform_candidate_sampler.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.uniform_candidate_sampler.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.uniform_candidate_sampler.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.nn.uniform_candidate_sampler.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.not_equal.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.not_equal.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.not_equal.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.not_equal.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.python_io.TFRecordWriter.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.python_io.TFRecordWriter.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.python_io.TFRecordWriter.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.python_io.TFRecordWriter.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scalar_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.scalar_summary.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scalar_summary.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.scalar_summary.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scatter_add.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.scatter_add.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scatter_add.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.scatter_add.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_mean.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.segment_mean.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_mean.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.segment_mean.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.set_random_seed.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.set_random_seed.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.set_random_seed.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.set_random_seed.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.space_to_batch.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.space_to_batch.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.space_to_batch.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.space_to_batch.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_retain.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.sparse_retain.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_retain.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.sparse_retain.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_segment_mean.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.sparse_segment_mean.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_segment_mean.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.sparse_segment_mean.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_softmax.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.sparse_softmax.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_softmax.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.sparse_softmax.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_split.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.sparse_split.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_split.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.sparse_split.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.split.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.split.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.split.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.split.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_strong.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.string_to_hash_bucket_strong.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_strong.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.string_to_hash_bucket_strong.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.AdagradOptimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.AdagradOptimizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.AdagradOptimizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.AdagradOptimizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.AdamOptimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.AdamOptimizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.AdamOptimizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.AdamOptimizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.input_producer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.input_producer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.input_producer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.input_producer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.shuffle_batch.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.shuffle_batch.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.shuffle_batch.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.shuffle_batch.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.update_checkpoint_state.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.update_checkpoint_state.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.update_checkpoint_state.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.train.update_checkpoint_state.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.trainable_variables.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.trainable_variables.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.trainable_variables.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.trainable_variables.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.unique_with_counts.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.unique_with_counts.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.unique_with_counts.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard1/tf.unique_with_counts.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Print.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.Print.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Print.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.Print.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Variable.from_proto.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.Variable.from_proto.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Variable.from_proto.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.Variable.from_proto.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.add_check_numerics_ops.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.add_check_numerics_ops.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.add_check_numerics_ops.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.add_check_numerics_ops.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.audio_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.audio_summary.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.audio_summary.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.audio_summary.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_fft2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.batch_fft2d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_fft2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.batch_fft2d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.case.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.case.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.case.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.case.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.concat.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.concat.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.concat.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.concat.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Gamma.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.distributions.Gamma.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Gamma.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.distributions.Gamma.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.ffmpeg.decode_audio.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.ffmpeg.decode_audio.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.ffmpeg.decode_audio.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.ffmpeg.decode_audio.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_activations.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.summarize_activations.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_activations.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.layers.summarize_activations.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.auc_using_histogram.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.metrics.auc_using_histogram.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.auc_using_histogram.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.metrics.auc_using_histogram.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_recall.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.metrics.streaming_recall.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_recall.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.metrics.streaming_recall.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.make_ndarray.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.util.make_ndarray.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.make_ndarray.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.util.make_ndarray.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.make_tensor_proto.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.util.make_tensor_proto.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.make_tensor_proto.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.contrib.util.make_tensor_proto.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.convert_to_tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.convert_to_tensor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.convert_to_tensor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.convert_to_tensor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.InternalError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.errors.InternalError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.InternalError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.errors.InternalError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.UnauthenticatedError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.errors.UnauthenticatedError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.UnauthenticatedError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.errors.UnauthenticatedError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.UnavailableError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.errors.UnavailableError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.UnavailableError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.errors.UnavailableError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.foldl.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.foldl.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.foldl.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.foldl.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_default_graph.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.get_default_graph.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_default_graph.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.get_default_graph.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_seed.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.get_seed.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_seed.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.get_seed.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.grayscale_to_rgb.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image.grayscale_to_rgb.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.grayscale_to_rgb.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image.grayscale_to_rgb.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_flip_up_down.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image.random_flip_up_down.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_flip_up_down.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image.random_flip_up_down.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_area.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image.resize_area.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_area.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image.resize_area.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_nearest_neighbor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image.resize_nearest_neighbor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_nearest_neighbor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.image.resize_nearest_neighbor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_nan.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.is_nan.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_nan.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.is_nan.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.multinomial.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.multinomial.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.multinomial.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.multinomial.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.conv3d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.nn.conv3d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.conv3d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.nn.conv3d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.l2_loss.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.nn.l2_loss.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.l2_loss.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.nn.l2_loss.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sampled_softmax_loss.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.nn.sampled_softmax_loss.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sampled_softmax_loss.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.nn.sampled_softmax_loss.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.softmax.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.nn.softmax.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.softmax.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.nn.softmax.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.pad.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.pad.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.pad.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.pad.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scatter_sub.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.scatter_sub.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scatter_sub.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.scatter_sub.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.select.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.select.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.select.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.select.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_mask.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.sparse_mask.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_mask.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.sparse_mask.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_double.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.to_double.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_double.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.to_double.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.FtrlOptimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.FtrlOptimizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.FtrlOptimizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.FtrlOptimizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.LooperThread.loop.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.LooperThread.loop.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.LooperThread.loop.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.LooperThread.loop.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Supervisor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.Supervisor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Supervisor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.Supervisor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.add_queue_runner.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.add_queue_runner.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.add_queue_runner.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.add_queue_runner.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.limit_epochs.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.limit_epochs.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.limit_epochs.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.limit_epochs.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.match_filenames_once.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.match_filenames_once.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.match_filenames_once.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.train.match_filenames_once.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.transpose.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.transpose.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.transpose.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.transpose.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.unpack.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.unpack.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.unpack.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.unpack.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.DeviceSpec.from_string.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.DeviceSpec.from_string.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.DeviceSpec.from_string.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.DeviceSpec.from_string.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.SparseTensorValue.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.SparseTensorValue.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.SparseTensorValue.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.SparseTensorValue.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.accumulate_n.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.accumulate_n.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.accumulate_n.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.accumulate_n.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.add_to_collection.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.add_to_collection.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.add_to_collection.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.add_to_collection.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_non_positive.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.assert_non_positive.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_non_positive.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.assert_non_positive.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_ifft3d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.batch_ifft3d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_ifft3d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.batch_ifft3d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matmul.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.batch_matmul.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matmul.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.batch_matmul.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_triangular_solve.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.batch_matrix_triangular_solve.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_triangular_solve.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.batch_matrix_triangular_solve.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.boolean_mask.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.boolean_mask.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.boolean_mask.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.boolean_mask.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.constant_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.constant_initializer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.constant_initializer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.constant_initializer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Gaussian.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.distributions.Normal.md similarity index 60% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Gaussian.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.distributions.Normal.md index 7fcc6249faa..d15dd93a652 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Gaussian.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.distributions.Normal.md @@ -1,4 +1,4 @@ -The scalar Gaussian distribution with mean and stddev parameters mu, sigma. +The scalar Normal distribution with mean and stddev parameters mu, sigma. #### Mathematical details @@ -11,15 +11,15 @@ The PDF of this distribution is: Examples of initialization of one or a batch of distributions. ```python -# Define a single scalar Gaussian distribution. -dist = tf.contrib.distributions.Gaussian(mu=0, sigma=3) +# Define a single scalar Normal distribution. +dist = tf.contrib.distributions.Normal(mu=0, sigma=3) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1) -# Define a batch of two scalar valued Gaussians. +# Define a batch of two scalar valued Normals. # The first has mean 1 and standard deviation 11, the second 2 and 22. -dist = tf.contrib.distributions.Gaussian(mu=[1, 2.], sigma=[11, 22.]) +dist = tf.contrib.distributions.Normal(mu=[1, 2.], sigma=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -32,9 +32,9 @@ dist.sample(3) Arguments are broadcast when possible. ```python -# Define a batch of two scalar valued Gaussians. +# Define a batch of two scalar valued Normals. # Both have mean 1, but different standard deviations. -dist = tf.contrib.distributions.Gaussian(mu=1, sigma=[11, 22.]) +dist = tf.contrib.distributions.Normal(mu=1, sigma=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. @@ -42,9 +42,9 @@ dist.pdf(3.0) ``` - - - -#### `tf.contrib.distributions.Gaussian.__init__(mu, sigma, name=None)` {#Gaussian.__init__} +#### `tf.contrib.distributions.Normal.__init__(mu, sigma, name=None)` {#Normal.__init__} -Construct Gaussian distributions with mean and stddev `mu` and `sigma`. +Construct Normal distributions with mean and stddev `mu` and `sigma`. The parameters `mu` and `sigma` must be shaped in a way that supports broadcasting (e.g. `mu + sigma` is a valid operation). @@ -65,9 +65,9 @@ broadcasting (e.g. `mu + sigma` is a valid operation). - - - -#### `tf.contrib.distributions.Gaussian.cdf(x, name=None)` {#Gaussian.cdf} +#### `tf.contrib.distributions.Normal.cdf(x, name=None)` {#Normal.cdf} -CDF of observations in `x` under these Gaussian distribution(s). +CDF of observations in `x` under these Normal distribution(s). ##### Args: @@ -83,16 +83,16 @@ CDF of observations in `x` under these Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.dtype` {#Gaussian.dtype} +#### `tf.contrib.distributions.Normal.dtype` {#Normal.dtype} - - - -#### `tf.contrib.distributions.Gaussian.entropy(name=None)` {#Gaussian.entropy} +#### `tf.contrib.distributions.Normal.entropy(name=None)` {#Normal.entropy} -The entropy of Gaussian distribution(s). +The entropy of Normal distribution(s). ##### Args: @@ -107,16 +107,16 @@ The entropy of Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.is_reparameterized` {#Gaussian.is_reparameterized} +#### `tf.contrib.distributions.Normal.is_reparameterized` {#Normal.is_reparameterized} - - - -#### `tf.contrib.distributions.Gaussian.log_cdf(x, name=None)` {#Gaussian.log_cdf} +#### `tf.contrib.distributions.Normal.log_cdf(x, name=None)` {#Normal.log_cdf} -Log CDF of observations `x` under these Gaussian distribution(s). +Log CDF of observations `x` under these Normal distribution(s). ##### Args: @@ -132,9 +132,9 @@ Log CDF of observations `x` under these Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.log_pdf(x, name=None)` {#Gaussian.log_pdf} +#### `tf.contrib.distributions.Normal.log_pdf(x, name=None)` {#Normal.log_pdf} -Log pdf of observations in `x` under these Gaussian distribution(s). +Log pdf of observations in `x` under these Normal distribution(s). ##### Args: @@ -150,23 +150,23 @@ Log pdf of observations in `x` under these Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.mean` {#Gaussian.mean} +#### `tf.contrib.distributions.Normal.mean` {#Normal.mean} - - - -#### `tf.contrib.distributions.Gaussian.mu` {#Gaussian.mu} +#### `tf.contrib.distributions.Normal.mu` {#Normal.mu} - - - -#### `tf.contrib.distributions.Gaussian.pdf(x, name=None)` {#Gaussian.pdf} +#### `tf.contrib.distributions.Normal.pdf(x, name=None)` {#Normal.pdf} -The PDF of observations in `x` under these Gaussian distribution(s). +The PDF of observations in `x` under these Normal distribution(s). ##### Args: @@ -182,9 +182,9 @@ The PDF of observations in `x` under these Gaussian distribution(s). - - - -#### `tf.contrib.distributions.Gaussian.sample(n, seed=None, name=None)` {#Gaussian.sample} +#### `tf.contrib.distributions.Normal.sample(n, seed=None, name=None)` {#Normal.sample} -Sample `n` observations from the Gaussian Distributions. +Sample `n` observations from the Normal Distributions. ##### Args: @@ -202,7 +202,7 @@ Sample `n` observations from the Gaussian Distributions. - - - -#### `tf.contrib.distributions.Gaussian.sigma` {#Gaussian.sigma} +#### `tf.contrib.distributions.Normal.sigma` {#Normal.sigma} diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.StudentT.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.distributions.StudentT.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.StudentT.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.distributions.StudentT.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_pandas_labels.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.extract_pandas_labels.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_pandas_labels.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.extract_pandas_labels.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.read_batch_examples.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.read_batch_examples.md similarity index 74% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.read_batch_examples.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.read_batch_examples.md index 7a3c3d7addb..c5cec0542aa 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.read_batch_examples.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.learn.read_batch_examples.md @@ -1,4 +1,4 @@ -### `tf.contrib.learn.read_batch_examples(file_pattern, batch_size, reader, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_examples')` {#read_batch_examples} +### `tf.contrib.learn.read_batch_examples(file_pattern, batch_size, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, num_threads=1, name=None)` {#read_batch_examples} Adds operations to read, queue, batch `Example` protos. @@ -20,6 +20,10 @@ All ops are added to the default graph. * `reader`: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). * `randomize_input`: Whether the input should be randomized. +* `num_epochs`: Integer specifying the number of times to read through the + dataset. If `None`, cycles through the dataset forever. + NOTE - If specified, creates a variable that must be initialized, so call + `tf.initialize_all_variables()` as shown in the tests. * `queue_capacity`: Capacity for input queue. * `num_threads`: The number of threads enqueuing examples. * `name`: Name of resulting op. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.set_union.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.metrics.set_union.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.set_union.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.metrics.set_union.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean_squared_error.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.metrics.streaming_mean_squared_error.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean_squared_error.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.metrics.streaming_mean_squared_error.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_sparse_precision_at_k.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.metrics.streaming_sparse_precision_at_k.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_sparse_precision_at_k.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.metrics.streaming_sparse_precision_at_k.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_sparse_recall_at_k.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.metrics.streaming_sparse_recall_at_k.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_sparse_recall_at_k.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.contrib.metrics.streaming_sparse_recall_at_k.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.decode_json_example.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.decode_json_example.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.decode_json_example.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.decode_json_example.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.CancelledError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.errors.CancelledError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.CancelledError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.errors.CancelledError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.OutOfRangeError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.errors.OutOfRangeError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.OutOfRangeError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.errors.OutOfRangeError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.expand_dims.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.expand_dims.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.expand_dims.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.expand_dims.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.floordiv.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.floordiv.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.floordiv.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.floordiv.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_collection.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.get_collection.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_collection.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.get_collection.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.global_norm.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.global_norm.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.global_norm.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.global_norm.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_bicubic.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.image.resize_bicubic.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_bicubic.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.image.resize_bicubic.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_images.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.image.resize_images.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_images.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.image.resize_images.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_finite.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.is_finite.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_finite.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.is_finite.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_non_decreasing.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.is_non_decreasing.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_non_decreasing.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.is_non_decreasing.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.listdiff.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.listdiff.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.listdiff.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.listdiff.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_solve.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.matrix_solve.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_solve.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.matrix_solve.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.fixed_unigram_candidate_sampler.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fixed_unigram_candidate_sampler.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.fixed_unigram_candidate_sampler.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.fixed_unigram_candidate_sampler.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.l2_normalize.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.l2_normalize.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.l2_normalize.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.l2_normalize.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.relu.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.relu.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.relu.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.relu.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.top_k.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.top_k.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.top_k.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.top_k.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.weighted_cross_entropy_with_logits.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.weighted_cross_entropy_with_logits.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.weighted_cross_entropy_with_logits.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.nn.weighted_cross_entropy_with_logits.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.parse_example.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.parse_example.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.parse_example.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.parse_example.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_shuffle.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.random_shuffle.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_shuffle.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.random_shuffle.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_mean.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.reduce_mean.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_mean.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.reduce_mean.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.rsqrt.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.rsqrt.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.rsqrt.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.rsqrt.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_max.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.segment_max.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_max.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.segment_max.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sin.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.sin.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sin.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.sin.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.size.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.size.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.size.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.size.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.squeeze.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.squeeze.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.squeeze.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.squeeze.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_int64.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.to_int64.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_int64.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.to_int64.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.GradientDescentOptimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.GradientDescentOptimizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.GradientDescentOptimizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.GradientDescentOptimizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Server.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.Server.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Server.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.Server.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.SessionManager.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.SessionManager.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.SessionManager.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.SessionManager.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.global_step.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.global_step.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.global_step.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.global_step.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.import_meta_graph.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.import_meta_graph.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.import_meta_graph.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.import_meta_graph.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.range_input_producer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.range_input_producer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.range_input_producer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.range_input_producer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.write_graph.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.write_graph.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.write_graph.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.train.write_graph.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_scope.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.variable_scope.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_scope.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard3/tf.variable_scope.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.FIFOQueue.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.FIFOQueue.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.FIFOQueue.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.FIFOQueue.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Operation.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.Operation.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Operation.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.Operation.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.argmax.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.argmax.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.argmax.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.argmax.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_determinant.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.batch_matrix_determinant.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_determinant.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.batch_matrix_determinant.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_solve.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.batch_matrix_solve.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_solve.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.batch_matrix_solve.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.gaussian_congugates_known_sigma_predictive.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.distributions.normal_congugates_known_sigma_predictive.md similarity index 70% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.gaussian_congugates_known_sigma_predictive.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.distributions.normal_congugates_known_sigma_predictive.md index bfa36e208e0..89e4e5ca3c5 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.gaussian_congugates_known_sigma_predictive.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.distributions.normal_congugates_known_sigma_predictive.md @@ -1,20 +1,20 @@ -### `tf.contrib.distributions.gaussian_congugates_known_sigma_predictive(prior, sigma, s, n)` {#gaussian_congugates_known_sigma_predictive} +### `tf.contrib.distributions.normal_congugates_known_sigma_predictive(prior, sigma, s, n)` {#normal_congugates_known_sigma_predictive} -Posterior predictive Gaussian distribution w. conjugate prior on the mean. +Posterior predictive Normal distribution w. conjugate prior on the mean. This model assumes that `n` observations (with sum `s`) come from a -Gaussian with unknown mean `mu` (described by the Gaussian `prior`) +Normal with unknown mean `mu` (described by the Normal `prior`) and known variance `sigma^2`. The "known sigma predictive" is the distribution of new observations, conditioned on the existing observations and our prior. -Accepts a prior Gaussian distribution object, having parameters +Accepts a prior Normal distribution object, having parameters `mu0` and `sigma0`, as well as known `sigma` values of the predictive -distribution(s) (also assumed Gaussian), +distribution(s) (also assumed Normal), and statistical estimates `s` (the sum(s) of the observations) and `n` (the number(s) of observations). -Calculates the Gaussian distribution(s) `p(x | sigma^2)`: +Calculates the Normal distribution(s) `p(x | sigma^2)`: ``` p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu @@ -36,7 +36,7 @@ will broadcast in the case of multidimensional sets of parameters. ##### Args: -* `prior`: `Gaussian` object of type `dtype`: +* `prior`: `Normal` object of type `dtype`: the prior distribution having parameters `(mu0, sigma0)`. * `sigma`: tensor of type `dtype`, taking values `sigma > 0`. The known stddev parameter(s). @@ -45,11 +45,11 @@ will broadcast in the case of multidimensional sets of parameters. ##### Returns: - A new Gaussian predictive distribution object. + A new Normal predictive distribution object. ##### Raises: * `TypeError`: if dtype of `s` does not match `dtype`, or `prior` is not a - Gaussian object. + Normal object. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.l1_regularizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.l1_regularizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.l1_regularizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.l1_regularizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_activation.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.summarize_activation.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_activation.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.summarize_activation.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.xavier_initializer_conv2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.xavier_initializer_conv2d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.xavier_initializer_conv2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.layers.xavier_initializer_conv2d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.evaluate.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.evaluate.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.evaluate.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.evaluate.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.infer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.infer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.infer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.learn.infer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.confusion_matrix.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.metrics.confusion_matrix.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.confusion_matrix.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.metrics.confusion_matrix.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.set_intersection.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.metrics.set_intersection.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.set_intersection.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.metrics.set_intersection.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_accuracy.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.metrics.streaming_accuracy.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_accuracy.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.metrics.streaming_accuracy.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_percentage_less.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.metrics.streaming_percentage_less.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_percentage_less.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.contrib.metrics.streaming_percentage_less.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.convert_to_tensor_or_indexed_slices.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.convert_to_tensor_or_indexed_slices.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.convert_to_tensor_or_indexed_slices.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.convert_to_tensor_or_indexed_slices.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cos.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.cos.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cos.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.cos.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.erfc.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.erfc.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.erfc.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.erfc.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.DataLossError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.errors.DataLossError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.DataLossError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.errors.DataLossError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.foldr.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.foldr.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.foldr.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.foldr.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.histogram_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.histogram_summary.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.histogram_summary.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.histogram_summary.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ifft3d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.ifft3d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ifft3d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.ifft3d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.convert_image_dtype.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.convert_image_dtype.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.convert_image_dtype.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.convert_image_dtype.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.pad_to_bounding_box.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.pad_to_bounding_box.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.pad_to_bounding_box.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.pad_to_bounding_box.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_saturation.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.random_saturation.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_saturation.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.random_saturation.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_image_with_crop_or_pad.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.resize_image_with_crop_or_pad.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_image_with_crop_or_pad.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.resize_image_with_crop_or_pad.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.rgb_to_grayscale.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.rgb_to_grayscale.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.rgb_to_grayscale.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.image.rgb_to_grayscale.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_inf.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.is_inf.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_inf.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.is_inf.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.logical_not.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.logical_not.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.logical_not.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.logical_not.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.make_template.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.make_template.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.make_template.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.make_template.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.depthwise_conv2d_native.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.nn.depthwise_conv2d_native.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.depthwise_conv2d_native.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.nn.depthwise_conv2d_native.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.local_response_normalization.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.nn.local_response_normalization.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.local_response_normalization.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.nn.local_response_normalization.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.max_pool.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.nn.max_pool.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.max_pool.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.nn.max_pool.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.max_pool_with_argmax.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.nn.max_pool_with_argmax.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.max_pool_with_argmax.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.nn.max_pool_with_argmax.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.no_op.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.no_op.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.no_op.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.no_op.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.placeholder_with_default.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.placeholder_with_default.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.placeholder_with_default.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.placeholder_with_default.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.python_io.tf_record_iterator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.python_io.tf_record_iterator.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.python_io.tf_record_iterator.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.python_io.tf_record_iterator.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_normal.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.random_normal.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_normal.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.random_normal.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_all.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.reduce_all.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_all.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.reduce_all.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.saturate_cast.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.saturate_cast.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.saturate_cast.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.saturate_cast.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_fill_empty_rows.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.sparse_fill_empty_rows.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_fill_empty_rows.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.sparse_fill_empty_rows.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_placeholder.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.sparse_placeholder.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_placeholder.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.sparse_placeholder.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_reorder.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.sparse_reorder.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_reorder.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.sparse_reorder.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_bfloat16.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.to_bfloat16.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_bfloat16.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.to_bfloat16.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.ClusterSpec.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.train.ClusterSpec.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.ClusterSpec.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.train.ClusterSpec.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.truncated_normal.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.truncated_normal.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.truncated_normal.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.truncated_normal.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_op_scope.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.variable_op_scope.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_op_scope.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.variable_op_scope.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.zeta.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.zeta.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.zeta.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard4/tf.zeta.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.DeviceSpec.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.DeviceSpec.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.DeviceSpec.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.DeviceSpec.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Graph.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.Graph.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Graph.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.Graph.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.IdentityReader.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.IdentityReader.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.IdentityReader.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.IdentityReader.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.QueueBase.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.QueueBase.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.QueueBase.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.QueueBase.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.RandomShuffleQueue.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.RandomShuffleQueue.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.RandomShuffleQueue.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.RandomShuffleQueue.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.RegisterShape.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.RegisterShape.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.RegisterShape.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.RegisterShape.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.SparseTensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.SparseTensor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.SparseTensor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.SparseTensor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.as_dtype.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.as_dtype.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.as_dtype.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.as_dtype.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_non_negative.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.assert_non_negative.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_non_negative.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.assert_non_negative.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_rank_at_least.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.assert_rank_at_least.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_rank_at_least.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.assert_rank_at_least.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_cholesky_solve.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_cholesky_solve.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_cholesky_solve.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_cholesky_solve.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_diag.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_matrix_diag.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_diag.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_matrix_diag.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_inverse.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_matrix_inverse.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_inverse.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_matrix_inverse.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_solve_ls.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_matrix_solve_ls.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_solve_ls.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_matrix_solve_ls.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_to_space.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_to_space.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_to_space.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.batch_to_space.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.check_numerics.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.check_numerics.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.check_numerics.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.check_numerics.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.clip_by_norm.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.clip_by_norm.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.clip_by_norm.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.clip_by_norm.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.clip_by_value.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.clip_by_value.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.clip_by_value.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.clip_by_value.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Uniform.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.distributions.Uniform.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Uniform.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.distributions.Uniform.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowEstimator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.learn.TensorFlowEstimator.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowEstimator.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.learn.TensorFlowEstimator.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_pandas_matrix.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.learn.extract_pandas_matrix.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_pandas_matrix.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.learn.extract_pandas_matrix.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.read_batch_features.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.learn.read_batch_features.md similarity index 69% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.read_batch_features.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.learn.read_batch_features.md index 1008bf64223..75b40f7e753 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.read_batch_features.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.learn.read_batch_features.md @@ -1,4 +1,4 @@ -### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, queue_capacity=10000, num_threads=1, name='dequeue_examples')` {#read_batch_features} +### `tf.contrib.learn.read_batch_features(file_pattern, batch_size, features, reader, randomize_input=True, num_epochs=None, queue_capacity=10000, reader_num_threads=1, parser_num_threads=1, name=None)` {#read_batch_features} Adds operations to read, queue, batch and parse `Example` protos. @@ -23,8 +23,13 @@ All ops are added to the default graph. * `reader`: A function or class that returns an object with `read` method, (filename tensor) -> (example tensor). * `randomize_input`: Whether the input should be randomized. +* `num_epochs`: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. NOTE - If specified, + creates a variable that must be initialized, so call + tf.initialize_all_variables() as shown in the tests. * `queue_capacity`: Capacity for input queue. -* `num_threads`: The number of threads enqueuing examples. +* `reader_num_threads`: The number of threads to read examples. +* `parser_num_threads`: The number of threads to parse examples. * `name`: Name of resulting op. ##### Returns: diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_recall_at_k.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.metrics.streaming_recall_at_k.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_recall_at_k.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.metrics.streaming_recall_at_k.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.stripped_op_list_for_graph.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.util.stripped_op_list_for_graph.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.stripped_op_list_for_graph.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.contrib.util.stripped_op_list_for_graph.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.decode_csv.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.decode_csv.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.decode_csv.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.decode_csv.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.edit_distance.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.edit_distance.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.edit_distance.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.edit_distance.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.fft2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.fft2d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.fft2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.fft2d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.floor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.floor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.floor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.floor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.gather.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.gather.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.gather.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.gather.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.gather_nd.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.gather_nd.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.gather_nd.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.gather_nd.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_collection_ref.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.get_collection_ref.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_collection_ref.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.get_collection_ref.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_variable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.get_variable.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_variable.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.get_variable.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ifft2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.ifft2d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ifft2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.ifft2d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.adjust_hue.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.adjust_hue.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.adjust_hue.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.adjust_hue.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.decode_jpeg.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.decode_jpeg.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.decode_jpeg.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.decode_jpeg.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.decode_png.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.decode_png.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.decode_png.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.decode_png.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.extract_glimpse.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.extract_glimpse.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.extract_glimpse.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.extract_glimpse.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.per_image_whitening.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.per_image_whitening.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.per_image_whitening.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.per_image_whitening.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_brightness.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.random_brightness.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_brightness.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.random_brightness.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_hue.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.random_hue.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_hue.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image.random_hue.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image_summary.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image_summary.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.image_summary.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.less.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.less.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.less.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.less.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.map_fn.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.map_fn.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.map_fn.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.map_fn.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.neg.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.neg.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.neg.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.neg.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.conv2d_transpose.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.conv2d_transpose.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.conv2d_transpose.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.conv2d_transpose.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.depthwise_conv2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.depthwise_conv2d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.depthwise_conv2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.depthwise_conv2d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.max_pool3d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.max_pool3d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.max_pool3d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.max_pool3d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.moments.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.moments.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.moments.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.moments.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.nce_loss.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.nce_loss.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.nce_loss.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.nce_loss.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.softplus.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.softplus.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.softplus.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.softplus.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.zero_fraction.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.zero_fraction.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.zero_fraction.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.nn.zero_fraction.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.one_hot.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.one_hot.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.one_hot.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.one_hot.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_normal_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.random_normal_initializer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_normal_initializer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.random_normal_initializer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_any.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.reduce_any.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_any.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.reduce_any.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reset_default_graph.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.reset_default_graph.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reset_default_graph.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.reset_default_graph.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.shape_n.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.shape_n.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.shape_n.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.shape_n.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_merge.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.sparse_merge.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_merge.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.sparse_merge.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_segment_sqrt_n_grad.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.sparse_segment_sqrt_n_grad.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_segment_sqrt_n_grad.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.sparse_segment_sqrt_n_grad.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_to_dense.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.sparse_to_dense.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_to_dense.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.sparse_to_dense.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.squared_difference.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.squared_difference.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.squared_difference.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.squared_difference.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sub.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.sub.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sub.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.sub.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.assert_equal_graph_def.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.test.assert_equal_graph_def.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.assert_equal_graph_def.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.test.assert_equal_graph_def.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.compute_gradient.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.test.compute_gradient.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.compute_gradient.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.test.compute_gradient.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.batch.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.batch.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.batch.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.batch.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.batch_join.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.batch_join.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.batch_join.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.batch_join.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.exponential_decay.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.exponential_decay.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.exponential_decay.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.exponential_decay.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.latest_checkpoint.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.latest_checkpoint.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.latest_checkpoint.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.latest_checkpoint.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.shuffle_batch_join.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.shuffle_batch_join.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.shuffle_batch_join.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.shuffle_batch_join.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.zeros_like.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.zeros_like.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.zeros_like.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.zeros_like.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.AggregationMethod.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.AggregationMethod.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.AggregationMethod.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.AggregationMethod.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Assert.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.Assert.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Assert.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.Assert.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.InteractiveSession.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.InteractiveSession.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.InteractiveSession.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.InteractiveSession.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.QueueBase.from_list.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.QueueBase.from_list.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.QueueBase.from_list.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.QueueBase.from_list.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.TextLineReader.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.TextLineReader.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.TextLineReader.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.TextLineReader.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.add.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.add.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.add.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.add.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_proper_iterable.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.assert_proper_iterable.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_proper_iterable.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.assert_proper_iterable.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_ifft.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.batch_ifft.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_ifft.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.batch_ifft.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_ifft2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.batch_ifft2d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_ifft2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.batch_ifft2d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cholesky.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.cholesky.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cholesky.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.cholesky.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.clip_by_global_norm.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.clip_by_global_norm.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.clip_by_global_norm.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.clip_by_global_norm.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Chi2.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.distributions.Chi2.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.Chi2.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.distributions.Chi2.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.optimize_loss.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.layers.optimize_loss.md similarity index 90% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.optimize_loss.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.layers.optimize_loss.md index db0b01186a2..8e6072532be 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.optimize_loss.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.layers.optimize_loss.md @@ -1,4 +1,4 @@ -### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, gradient_noise_scale=None, gradient_multipliers=None, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, variables=None, name=None)` {#optimize_loss} +### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, gradient_noise_scale=None, gradient_multipliers=None, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, update_ops=None, variables=None, name=None)` {#optimize_loss} Given loss and parameters for optimizer, returns a training op. @@ -28,6 +28,8 @@ Given loss and parameters for optimizer, returns a training op. Can be used to implement any learning rate decay functions. For example: tf.train.exponential_decay. +* `update_ops`: list of update `Operation`s to execute at each step. If `None`, + uses elements of UPDATE_OPS collection. * `variables`: list of variables to optimize or `None` to use all trainable variables. * `name`: The name for this operation is used to scope operations and summaries. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_collection.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.layers.summarize_collection.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_collection.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.layers.summarize_collection.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.Estimator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.Estimator.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.Estimator.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.Estimator.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.ModeKeys.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.ModeKeys.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.ModeKeys.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.ModeKeys.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.NanLossDuringTrainingError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.NanLossDuringTrainingError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.NanLossDuringTrainingError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.NanLossDuringTrainingError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowDNNRegressor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.TensorFlowDNNRegressor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowDNNRegressor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.TensorFlowDNNRegressor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowLinearClassifier.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.TensorFlowLinearClassifier.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowLinearClassifier.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.TensorFlowLinearClassifier.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowRNNClassifier.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.TensorFlowRNNClassifier.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowRNNClassifier.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.TensorFlowRNNClassifier.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.run_n.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.run_n.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.run_n.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.learn.run_n.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.metrics.streaming_mean.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_mean.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.metrics.streaming_mean.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_precision.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.metrics.streaming_precision.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_precision.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.contrib.metrics.streaming_precision.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.AlreadyExistsError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.errors.AlreadyExistsError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.AlreadyExistsError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.errors.AlreadyExistsError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.InvalidArgumentError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.errors.InvalidArgumentError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.InvalidArgumentError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.errors.InvalidArgumentError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.UnknownError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.errors.UnknownError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.UnknownError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.errors.UnknownError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.fft3d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.fft3d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.fft3d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.fft3d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.igamma.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.igamma.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.igamma.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.igamma.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.igammac.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.igammac.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.igammac.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.igammac.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.adjust_saturation.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.image.adjust_saturation.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.adjust_saturation.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.image.adjust_saturation.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.hsv_to_rgb.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.image.hsv_to_rgb.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.hsv_to_rgb.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.image.hsv_to_rgb.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.initialize_all_variables.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.initialize_all_variables.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.initialize_all_variables.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.initialize_all_variables.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.less_equal.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.less_equal.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.less_equal.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.less_equal.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matmul.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.matmul.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matmul.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.matmul.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_inverse.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.matrix_inverse.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_inverse.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.matrix_inverse.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_triangular_solve.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.matrix_triangular_solve.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matrix_triangular_solve.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.matrix_triangular_solve.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.merge_summary.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.merge_summary.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.merge_summary.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.merge_summary.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.moving_average_variables.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.moving_average_variables.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.moving_average_variables.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.moving_average_variables.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.bias_add.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.nn.bias_add.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.bias_add.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.nn.bias_add.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.elu.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.nn.elu.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.elu.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.nn.elu.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.learned_unigram_candidate_sampler.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.nn.learned_unigram_candidate_sampler.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.learned_unigram_candidate_sampler.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.nn.learned_unigram_candidate_sampler.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.log_uniform_candidate_sampler.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.nn.log_uniform_candidate_sampler.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.log_uniform_candidate_sampler.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.nn.log_uniform_candidate_sampler.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.no_regularizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.no_regularizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.no_regularizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.no_regularizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.op_scope.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.op_scope.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.op_scope.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.op_scope.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.rank.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.rank.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.rank.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.rank.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_join.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.reduce_join.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_join.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.reduce_join.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_sum.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.reduce_sum.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_sum.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.reduce_sum.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.shape.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.shape.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.shape.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.shape.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.space_to_depth.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.space_to_depth.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.space_to_depth.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.space_to_depth.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_concat.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.sparse_concat.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_concat.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.sparse_concat.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_tensor_to_dense.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.sparse_tensor_to_dense.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_tensor_to_dense.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.sparse_tensor_to_dense.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.string_to_hash_bucket.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.string_to_hash_bucket.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_fast.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.string_to_hash_bucket_fast.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_hash_bucket_fast.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.string_to_hash_bucket_fast.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.get_temp_dir.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.test.get_temp_dir.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.get_temp_dir.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.test.get_temp_dir.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.main.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.test.main.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.main.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.test.main.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.AdadeltaOptimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.AdadeltaOptimizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.AdadeltaOptimizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.AdadeltaOptimizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.MomentumOptimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.MomentumOptimizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.MomentumOptimizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.MomentumOptimizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.QueueRunner.from_proto.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.QueueRunner.from_proto.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.QueueRunner.from_proto.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.QueueRunner.from_proto.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.QueueRunner.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.QueueRunner.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.QueueRunner.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.QueueRunner.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.RMSPropOptimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.RMSPropOptimizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.RMSPropOptimizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.RMSPropOptimizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.export_meta_graph.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.export_meta_graph.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.export_meta_graph.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.export_meta_graph.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.replica_device_setter.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.replica_device_setter.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.replica_device_setter.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.replica_device_setter.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.slice_input_producer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.slice_input_producer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.slice_input_producer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.train.slice_input_producer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.uniform_unit_scaling_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.uniform_unit_scaling_initializer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.uniform_unit_scaling_initializer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.uniform_unit_scaling_initializer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.unsorted_segment_sum.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.unsorted_segment_sum.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.unsorted_segment_sum.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.unsorted_segment_sum.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.zeros.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.zeros.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.zeros.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard6/tf.zeros.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.FixedLengthRecordReader.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.FixedLengthRecordReader.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.FixedLengthRecordReader.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.FixedLengthRecordReader.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.IndexedSlices.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.IndexedSlices.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.IndexedSlices.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.IndexedSlices.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.VariableScope.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.VariableScope.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.VariableScope.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.VariableScope.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_equal.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.assert_equal.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_equal.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.assert_equal.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_integer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.assert_integer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_integer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.assert_integer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_less.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.assert_less.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_less.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.assert_less.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_negative.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.assert_negative.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_negative.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.assert_negative.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_fft3d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.batch_fft3d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_fft3d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.batch_fft3d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.bitcast.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.bitcast.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.bitcast.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.bitcast.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cast.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.cast.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cast.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.cast.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ceil.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.ceil.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ceil.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.ceil.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.clip_by_average_norm.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.clip_by_average_norm.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.clip_by_average_norm.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.clip_by_average_norm.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.BaseDistribution.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.distributions.BaseDistribution.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.BaseDistribution.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.distributions.BaseDistribution.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.ContinuousDistribution.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.distributions.ContinuousDistribution.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.ContinuousDistribution.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.distributions.ContinuousDistribution.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.convolution2d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.convolution2d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.convolution2d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.convolution2d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.l2_regularizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.l2_regularizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.l2_regularizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.l2_regularizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.sum_regularizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.sum_regularizer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.sum_regularizer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.sum_regularizer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_tensors.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.summarize_tensors.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.summarize_tensors.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.layers.summarize_tensors.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.BaseEstimator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.learn.BaseEstimator.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.BaseEstimator.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.learn.BaseEstimator.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowRegressor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.learn.TensorFlowRegressor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowRegressor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.contrib.learn.TensorFlowRegressor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cross.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.cross.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cross.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.cross.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.delete_session_tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.delete_session_tensor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.delete_session_tensor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.delete_session_tensor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.dynamic_stitch.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.dynamic_stitch.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.dynamic_stitch.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.dynamic_stitch.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.AbortedError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.errors.AbortedError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.AbortedError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.errors.AbortedError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.PermissionDeniedError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.errors.PermissionDeniedError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.PermissionDeniedError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.errors.PermissionDeniedError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.fft.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.fft.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.fft.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.fft.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.greater.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.greater.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.greater.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.greater.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_bilinear.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.image.resize_bilinear.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.resize_bilinear.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.image.resize_bilinear.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_strictly_increasing.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.is_strictly_increasing.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_strictly_increasing.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.is_strictly_increasing.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_variable_initialized.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.is_variable_initialized.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_variable_initialized.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.is_variable_initialized.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.lgamma.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.lgamma.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.lgamma.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.lgamma.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.linspace.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.linspace.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.linspace.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.linspace.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matching_files.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.matching_files.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.matching_files.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.matching_files.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.maximum.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.maximum.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.maximum.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.maximum.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.mul.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.mul.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.mul.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.mul.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.compute_accidental_hits.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.compute_accidental_hits.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.compute_accidental_hits.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.compute_accidental_hits.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.dropout.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.dropout.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.dropout.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.dropout.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.embedding_lookup_sparse.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.embedding_lookup_sparse.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.embedding_lookup_sparse.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.embedding_lookup_sparse.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.normalize_moments.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.normalize_moments.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.normalize_moments.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.normalize_moments.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.softmax_cross_entropy_with_logits.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.softmax_cross_entropy_with_logits.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.softmax_cross_entropy_with_logits.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.nn.softmax_cross_entropy_with_logits.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.parse_single_example.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.parse_single_example.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.parse_single_example.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.parse_single_example.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.pow.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.pow.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.pow.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.pow.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_crop.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.random_crop.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_crop.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.random_crop.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_max.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.reduce_max.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_max.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.reduce_max.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_prod.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.reduce_prod.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_prod.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.reduce_prod.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reshape.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.reshape.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reshape.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.reshape.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.round.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.round.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.round.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.round.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scalar_mul.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.scalar_mul.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scalar_mul.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.scalar_mul.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scan.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.scan.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scan.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.scan.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_min.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.segment_min.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_min.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.segment_min.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sigmoid.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.sigmoid.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sigmoid.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.sigmoid.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sign.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.sign.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sign.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.sign.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.slice.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.slice.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.slice.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.slice.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_segment_sqrt_n.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.sparse_segment_sqrt_n.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_segment_sqrt_n.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.sparse_segment_sqrt_n.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_to_indicator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.sparse_to_indicator.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_to_indicator.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.sparse_to_indicator.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_number.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.string_to_number.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.string_to_number.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.string_to_number.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.LooperThread.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.LooperThread.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.LooperThread.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.LooperThread.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.get_checkpoint_state.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.get_checkpoint_state.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.get_checkpoint_state.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.train.get_checkpoint_state.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.tuple.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.tuple.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.tuple.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.tuple.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.unique.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.unique.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.unique.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard7/tf.unique.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.FixedLenFeature.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.FixedLenFeature.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.FixedLenFeature.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.FixedLenFeature.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.all_variables.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.all_variables.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.all_variables.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.all_variables.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_type.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.assert_type.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_type.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.assert_type.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_diag_part.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.batch_matrix_diag_part.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_matrix_diag_part.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.batch_matrix_diag_part.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.bytes.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.bytes.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.bytes.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.bytes.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cholesky_solve.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.cholesky_solve.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cholesky_solve.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.cholesky_solve.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cond.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.cond.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.cond.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.cond.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.conj.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.conj.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.conj.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.conj.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.constant.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.constant.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.constant.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.constant.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.copy_graph.get_copied_op.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.copy_graph.get_copied_op.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.copy_graph.get_copied_op.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.copy_graph.get_copied_op.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.ffmpeg.encode_audio.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.ffmpeg.encode_audio.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.ffmpeg.encode_audio.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.ffmpeg.encode_audio.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.variance_scaling_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.layers.variance_scaling_initializer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.variance_scaling_initializer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.layers.variance_scaling_initializer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowDNNClassifier.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.learn.TensorFlowDNNClassifier.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowDNNClassifier.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.learn.TensorFlowDNNClassifier.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowRNNRegressor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.learn.TensorFlowRNNRegressor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowRNNRegressor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.learn.TensorFlowRNNRegressor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_pandas_data.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.learn.extract_pandas_data.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_pandas_data.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.learn.extract_pandas_data.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.set_size.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.metrics.set_size.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.set_size.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.metrics.set_size.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.constant_value.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.util.constant_value.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.constant_value.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.util.constant_value.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.ops_used_by_graph_def.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.util.ops_used_by_graph_def.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.util.ops_used_by_graph_def.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.contrib.util.ops_used_by_graph_def.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.count_up_to.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.count_up_to.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.count_up_to.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.count_up_to.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.depth_to_space.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.depth_to_space.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.depth_to_space.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.depth_to_space.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.device.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.device.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.device.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.device.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.div.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.div.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.div.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.div.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.DeadlineExceededError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.errors.DeadlineExceededError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.DeadlineExceededError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.errors.DeadlineExceededError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.FailedPreconditionError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.errors.FailedPreconditionError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.FailedPreconditionError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.errors.FailedPreconditionError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.NotFoundError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.errors.NotFoundError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.NotFoundError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.errors.NotFoundError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.exp.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.exp.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.exp.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.exp.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_default_session.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.get_default_session.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_default_session.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.get_default_session.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_session_handle.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.get_session_handle.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_session_handle.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.get_session_handle.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_variable_scope.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.get_variable_scope.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_variable_scope.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.get_variable_scope.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.gradients.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.gradients.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.gradients.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.gradients.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.identity.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.identity.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.identity.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.identity.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.imag.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.imag.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.imag.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.imag.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.crop_to_bounding_box.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.image.crop_to_bounding_box.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.crop_to_bounding_box.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.image.crop_to_bounding_box.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.flip_left_right.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.image.flip_left_right.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.flip_left_right.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.image.flip_left_right.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.rgb_to_hsv.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.image.rgb_to_hsv.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.rgb_to_hsv.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.image.rgb_to_hsv.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.logical_and.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.logical_and.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.logical_and.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.logical_and.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.merge_all_summaries.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.merge_all_summaries.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.merge_all_summaries.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.merge_all_summaries.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.name_scope.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.name_scope.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.name_scope.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.name_scope.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.relu6.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.nn.relu6.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.relu6.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.nn.relu6.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ones_like.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.ones_like.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ones_like.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.ones_like.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_uniform.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.random_uniform.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_uniform.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.random_uniform.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.read_file.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.read_file.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.read_file.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.read_file.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.real.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.real.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.real.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.real.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.report_uninitialized_variables.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.report_uninitialized_variables.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.report_uninitialized_variables.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.report_uninitialized_variables.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scatter_update.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.scatter_update.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.scatter_update.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.scatter_update.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_reset_shape.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.sparse_reset_shape.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_reset_shape.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.sparse_reset_shape.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_segment_sum.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.sparse_segment_sum.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_segment_sum.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.sparse_segment_sum.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.tile.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.tile.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.tile.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.tile.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_float.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.to_float.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.to_float.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.to_float.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.ExponentialMovingAverage.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.train.ExponentialMovingAverage.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.ExponentialMovingAverage.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.train.ExponentialMovingAverage.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.SummaryWriter.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.train.SummaryWriter.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.SummaryWriter.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.train.SummaryWriter.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.generate_checkpoint_state_proto.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.train.generate_checkpoint_state_proto.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.generate_checkpoint_state_proto.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.train.generate_checkpoint_state_proto.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_axis_size_partitioner.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.variable_axis_size_partitioner.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.variable_axis_size_partitioner.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.variable_axis_size_partitioner.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.where.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.where.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.where.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard8/tf.where.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.FixedLenSequenceFeature.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.FixedLenSequenceFeature.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.FixedLenSequenceFeature.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.FixedLenSequenceFeature.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.GraphKeys.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.GraphKeys.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.GraphKeys.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.GraphKeys.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.NoGradient.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.NoGradient.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.NoGradient.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.NoGradient.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Session.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.Session.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.Session.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.Session.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.TensorShape.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.TensorShape.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.TensorShape.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.TensorShape.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.VarLenFeature.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.VarLenFeature.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.VarLenFeature.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.VarLenFeature.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.WholeFileReader.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.WholeFileReader.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.WholeFileReader.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.WholeFileReader.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_positive.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.assert_positive.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.assert_positive.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.assert_positive.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_cholesky.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.batch_cholesky.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.batch_cholesky.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.batch_cholesky.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.complex.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.complex.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.complex.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.complex.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.copy_graph.copy_variable_to_graph.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.copy_graph.copy_variable_to_graph.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.copy_graph.copy_variable_to_graph.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.copy_graph.copy_variable_to_graph.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.DiscreteDistribution.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.distributions.DiscreteDistribution.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.distributions.DiscreteDistribution.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.distributions.DiscreteDistribution.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.fully_connected.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.layers.fully_connected.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.layers.fully_connected.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.layers.fully_connected.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowClassifier.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.TensorFlowClassifier.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.TensorFlowClassifier.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.TensorFlowClassifier.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_dask_labels.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.extract_dask_labels.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.extract_dask_labels.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.extract_dask_labels.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.run_feeds.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.run_feeds.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.learn.run_feeds.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.learn.run_feeds.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.accuracy.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.metrics.accuracy.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.accuracy.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.metrics.accuracy.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_auc.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.metrics.streaming_auc.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_auc.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.metrics.streaming_auc.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_root_mean_squared_error.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.metrics.streaming_root_mean_squared_error.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.contrib.metrics.streaming_root_mean_squared_error.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.contrib.metrics.streaming_root_mean_squared_error.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.decode_raw.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.decode_raw.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.decode_raw.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.decode_raw.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.digamma.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.digamma.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.digamma.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.digamma.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.equal.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.equal.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.equal.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.equal.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.UnimplementedError.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.errors.UnimplementedError.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.errors.UnimplementedError.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.errors.UnimplementedError.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.fill.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.fill.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.fill.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.fill.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_session_tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.get_session_tensor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.get_session_tensor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.get_session_tensor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.flip_up_down.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.image.flip_up_down.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.flip_up_down.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.image.flip_up_down.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_flip_left_right.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.image.random_flip_left_right.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.image.random_flip_left_right.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.image.random_flip_left_right.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_numeric_tensor.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.is_numeric_tensor.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.is_numeric_tensor.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.is_numeric_tensor.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.lbeta.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.lbeta.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.lbeta.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.lbeta.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.load_file_system_library.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.load_file_system_library.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.load_file_system_library.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.load_file_system_library.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.load_op_library.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.load_op_library.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.load_op_library.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.load_op_library.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.mod.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.mod.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.mod.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.mod.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.avg_pool3d.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.avg_pool3d.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.avg_pool3d.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.avg_pool3d.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sparse_softmax_cross_entropy_with_logits.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.sparse_softmax_cross_entropy_with_logits.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.nn.sparse_softmax_cross_entropy_with_logits.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.nn.sparse_softmax_cross_entropy_with_logits.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ones.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.ones.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.ones.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.ones.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.pack.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.pack.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.pack.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.pack.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.py_func.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.py_func.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.py_func.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.py_func.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_uniform_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.random_uniform_initializer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.random_uniform_initializer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.random_uniform_initializer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_min.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.reduce_min.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reduce_min.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.reduce_min.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reverse.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.reverse.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.reverse.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.reverse.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_prod.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.segment_prod.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.segment_prod.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.segment_prod.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_tensor_dense_matmul.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.sparse_tensor_dense_matmul.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.sparse_tensor_dense_matmul.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.sparse_tensor_dense_matmul.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.square.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.square.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.square.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.square.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.tanh.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.tanh.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.tanh.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.tanh.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.is_built_with_cuda.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.test.is_built_with_cuda.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.test.is_built_with_cuda.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.test.is_built_with_cuda.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Coordinator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.Coordinator.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Coordinator.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.Coordinator.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Server.create_local_server.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.Server.create_local_server.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.Server.create_local_server.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.Server.create_local_server.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.start_queue_runners.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.start_queue_runners.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.start_queue_runners.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.start_queue_runners.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.summary_iterator.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.train.summary_iterator.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.train.summary_iterator.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.truncated_normal_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.truncated_normal_initializer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.truncated_normal_initializer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.truncated_normal_initializer.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.verify_tensor_all_finite.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.verify_tensor_all_finite.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.verify_tensor_all_finite.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.verify_tensor_all_finite.md diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/tf.zeros_initializer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.zeros_initializer.md similarity index 100% rename from tensorflow/g3doc/api_docs/python/functions_and_classes/tf.zeros_initializer.md rename to tensorflow/g3doc/api_docs/python/functions_and_classes/shard9/tf.zeros_initializer.md diff --git a/tensorflow/g3doc/api_docs/python/index.md b/tensorflow/g3doc/api_docs/python/index.md index 1617ebffa9e..1a199d25195 100644 --- a/tensorflow/g3doc/api_docs/python/index.md +++ b/tensorflow/g3doc/api_docs/python/index.md @@ -530,10 +530,10 @@ * [`DiscreteDistribution`](../../api_docs/python/contrib.distributions.md#DiscreteDistribution) * [`Exponential`](../../api_docs/python/contrib.distributions.md#Exponential) * [`Gamma`](../../api_docs/python/contrib.distributions.md#Gamma) - * [`Gaussian`](../../api_docs/python/contrib.distributions.md#Gaussian) - * [`gaussian_congugates_known_sigma_predictive`](../../api_docs/python/contrib.distributions.md#gaussian_congugates_known_sigma_predictive) - * [`gaussian_conjugates_known_sigma_posterior`](../../api_docs/python/contrib.distributions.md#gaussian_conjugates_known_sigma_posterior) * [`MultivariateNormal`](../../api_docs/python/contrib.distributions.md#MultivariateNormal) + * [`Normal`](../../api_docs/python/contrib.distributions.md#Normal) + * [`normal_congugates_known_sigma_predictive`](../../api_docs/python/contrib.distributions.md#normal_congugates_known_sigma_predictive) + * [`normal_conjugates_known_sigma_posterior`](../../api_docs/python/contrib.distributions.md#normal_conjugates_known_sigma_posterior) * [`StudentT`](../../api_docs/python/contrib.distributions.md#StudentT) * [`Uniform`](../../api_docs/python/contrib.distributions.md#Uniform) diff --git a/tensorflow/g3doc/api_docs/python/nn.md b/tensorflow/g3doc/api_docs/python/nn.md index 35a94808c1e..9405ed74d1e 100644 --- a/tensorflow/g3doc/api_docs/python/nn.md +++ b/tensorflow/g3doc/api_docs/python/nn.md @@ -418,6 +418,12 @@ horizontal and vertical strides, `strides = [1, stride, stride, 1]`. A 4-D `Tensor` of shape `[batch, out_height, out_width, out_channels]`. +##### Raises: + + +* `ValueError`: If channel_multiplier * in_channels > out_channels, + which means that the separable convolution is overparameterized. + - - - diff --git a/tensorflow/g3doc/get_started/basic_usage.md b/tensorflow/g3doc/get_started/basic_usage.md index ad95472574f..b4289a986d1 100644 --- a/tensorflow/g3doc/get_started/basic_usage.md +++ b/tensorflow/g3doc/get_started/basic_usage.md @@ -274,9 +274,9 @@ example we fetched the single node `state`, but you can also fetch multiple tensors: ```python -input1 = tf.constant(3.0) -input2 = tf.constant(2.0) -input3 = tf.constant(5.0) +input1 = tf.constant([3.0]) +input2 = tf.constant([2.0]) +input3 = tf.constant([5.0]) intermed = tf.add(input2, input3) mul = tf.mul(input1, intermed) diff --git a/tensorflow/python/framework/docs.py b/tensorflow/python/framework/docs.py index 12dfc71ffcd..92c6cfb3d92 100644 --- a/tensorflow/python/framework/docs.py +++ b/tensorflow/python/framework/docs.py @@ -33,6 +33,8 @@ _always_drop_symbol_re = re.compile("_[_a-zA-Z0-9]") _anchor_re = re.compile(r"^[\w.]+$") _member_mark = "@@" _indiv_dir = "functions_and_classes" +_num_subdirs = 10 +_subdir_prefix = "shard" class Document(object): @@ -250,6 +252,24 @@ class Library(Document): or self._should_include_member(name)): yield name, ("%s.%s" % (cls_name, name), member) + def shard_dir(self, name): + """Returns the path of the doc subdirectory for member `name`. + + When generating individual files for each function and class, we shard + the files across several directories to avoid hitting the limit for + files per directory. This function determines the subdirectory for + a member based on a hash of its name. + + Args: + name: string. The name of a function or class. + + Returns: + The path to a subdirectory of the api docs directory. + """ + index = hash(name) % _num_subdirs + return os.path.join(self.functions_and_classes_dir, + _subdir_prefix + str(index)) + def set_functions_and_classes_dir(self, dirname): """Sets the name of the directory for function and class markdown files. @@ -400,7 +420,7 @@ class Library(Document): # Write an individual file for each function. if inspect.isfunction(member): indivf = open( - os.path.join(self.functions_and_classes_dir, name + ".md"), "w+") + os.path.join(self.shard_dir(name), name + ".md"), "w+") self._print_function(indivf, prefix, name, member) elif inspect.isclass(member): print("- - -", file=f) @@ -414,7 +434,7 @@ class Library(Document): # Write an individual file for each class. indivf = open( - os.path.join(self.functions_and_classes_dir, name + ".md"), "w+") + os.path.join(self.shard_dir(name), name + ".md"), "w+") self._write_class_markdown_to_file(indivf, name, member) else: raise RuntimeError("Member %s has unknown type %s" % (name, type(member))) @@ -547,11 +567,17 @@ def write_libraries(output_dir, libraries): files = [open(os.path.join(output_dir, k), "w") for k, _ in libraries] # Set the directory in which to save individual class and function md files, - # creating it if it doesn't exist. + # creating it if it doesn't exist. Create subdirectories to avoid hitting + # the limit for number of files in a directory. indiv_dir = os.path.join(output_dir, _indiv_dir) if not os.path.exists(indiv_dir): os.makedirs(indiv_dir) + for i in range(0, _num_subdirs): + subdir = os.path.join(indiv_dir, _subdir_prefix + str(i)) + if not os.path.exists(subdir): + os.makedirs(subdir) + # Document mentioned symbols for all libraries for f, (_, v) in zip(files, libraries): v.set_functions_and_classes_dir(indiv_dir) diff --git a/tensorflow/python/kernel_tests/batchtospace_op_test.py b/tensorflow/python/kernel_tests/batchtospace_op_test.py index df46c446fc6..2bc1fd5e10c 100644 --- a/tensorflow/python/kernel_tests/batchtospace_op_test.py +++ b/tensorflow/python/kernel_tests/batchtospace_op_test.py @@ -88,6 +88,11 @@ class BatchToSpaceErrorHandlingTest(tf.test.TestCase): with self.assertRaises(IndexError): _ = tf.batch_to_space(x_np, crops, block_size) + def testUnknownShape(self): + t = tf.batch_to_space(tf.placeholder(tf.float32), tf.placeholder(tf.int32), + block_size=4) + self.assertEqual(4, t.get_shape().ndims) + class BatchToSpaceGradientTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py index 199b54512e0..0189280e248 100644 --- a/tensorflow/python/kernel_tests/cholesky_op_test.py +++ b/tensorflow/python/kernel_tests/cholesky_op_test.py @@ -71,18 +71,23 @@ class CholeskyOpTest(tf.test.TestCase): def testNonSquareMatrix(self): with self.assertRaises(ValueError): tf.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]])) + with self.assertRaises(ValueError): + tf.batch_cholesky(np.array([[[1., 2., 3.], [3., 4., 5.]], + [[1., 2., 3.], [3., 4., 5.]]])) def testWrongDimensions(self): tensor3 = tf.constant([1., 2.]) with self.assertRaises(ValueError): tf.cholesky(tensor3) + with self.assertRaises(ValueError): + tf.batch_cholesky(tensor3) def testNotInvertible(self): - # The input should be invertible. + # The input should be invertible. with self.test_session(): with self.assertRaisesOpError("LLT decomposition was not successful. The" " input might not be valid."): - # All rows of the matrix below add to zero + # All rows of the matrix below add to zero self._verifyCholesky(np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1., 1.]])) @@ -122,24 +127,36 @@ class CholeskyGradTest(tf.test.TestCase): scalarTest=False): with self.test_session(use_gpu=False): for shape in shapes: - for dtype in dtypes: - if not(scalarTest): - x = tf.constant(np.random.randn(shape[0], shape[1]), dtype) - K = tf.matmul(x, tf.transpose(x)) / shape[0] # K is posdef - y = tf.cholesky(K) - else: # This is designed to be a faster test for larger matrices. - x = tf.constant(np.random.randn(), dtype) - R = tf.constant(np.random.randn(shape[0], shape[1]), dtype) - e = tf.mul(R, x) - K = tf.matmul(e, tf.transpose(e)) / shape[0] # K is posdef - y = tf.reduce_mean(tf.cholesky(K)) - error = tf.test.compute_gradient_error(x, x._shape_as_list(), - y, y._shape_as_list()) - tf.logging.info("error = %f", error) - if dtype == tf.float64: - self.assertLess(error, 1e-5) - else: - self.assertLess(error, 2e-3) + for batch in False, True: + for dtype in dtypes: + if not scalarTest: + x = tf.constant(np.random.randn(shape[0], shape[1]), dtype) + tensor = tf.matmul(x, tf.transpose(x)) / shape[0] + else: + # This is designed to be a faster test for larger matrices. + x = tf.constant(np.random.randn(), dtype) + R = tf.constant(np.random.randn(shape[0], shape[1]), dtype) + e = tf.mul(R, x) + tensor = tf.matmul(e, tf.transpose(e)) / shape[0] + + # Inner-most matrices in tensor are positive definite. + if batch: + tensor = tf.tile(tf.expand_dims(tensor, 0), [4, 1, 1]) + op = tf.batch_cholesky + else: + op = tf.cholesky + + if not (scalarTest): + y = op(tensor) + else: + y = tf.reduce_mean(op(tensor)) + error = tf.test.compute_gradient_error(x, x._shape_as_list(), y, + y._shape_as_list()) + tf.logging.info("error = %f", error) + if dtype == tf.float64: + self.assertLess(error, 1e-5) + else: + self.assertLess(error, 3e-3) if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index 74e91b826a7..50bb643402d 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -1059,9 +1059,9 @@ class SeparableConv2DTest(tf.test.TestCase): def testSeparableConv2D(self): # The output is the result of two convolutions: - # First with tensor_in[1, 4, 4, 3] * filter1[2, 2, 3, 3]. - # Second with intermediate_out[4, 4, 3, 3] * filter2[1, 1, 3, 6]. - # Complexity is O(3*3*2*2 + 3*6*1*1] as opposed to O(3*6*2*2). + # First with tensor_in[1, 4, 4, 2] * filter1[2, 2, 2, 3]. + # Second with intermediate_out[1, 4, 4, 6] * filter2[1, 1, 6, 7]. + # Complexity is O(2*3*2*2 + 6*7*1*1) as opposed to O(2*7*2*2). expected_output = [ 6644.5, 6971.5, 7298.5, 7625.5, 7952.5, 8279.5, 8606.5, 8154.5, 8556.5, 8958.5, 9360.5, 9762.5, 10164.5, 10566.5, 9664.5, 10141.5, 10618.5, @@ -1084,6 +1084,46 @@ class SeparableConv2DTest(tf.test.TestCase): stride=1, padding="SAME", expected=expected_output) + def testSeparableConv2DEqualInputOutputDepth(self): + # The output is the result of two convolutions: + # First with tensor_in[1, 4, 4, 2] * filter1[2, 2, 3, 3]. + # Second with intermediate_out[1, 4, 4, 6] * filter2[1, 1, 6, 6]. + # Complexity is O(2*3*2*2 + 6*6*1*1) as opposed to O(2*6*2*2). + expected_output = [ + 5742.0, 6069.0, 6396.0, 6723.0, 7050.0, 7377.0, + 7047.0, 7449.0, 7851.0, 8253.0, 8655.0, 9057.0, + 8352.0, 8829.0, 9306.0, 9783.0, 10260.0, 10737.0, + 3582.0, 3783.0, 3984.0, 4185.0, 4386.0, 4587.0, + 10962.0, 11589.0, 12216.0, 12843.0, 13470.0, 14097.0, + 12267.0, 12969.0, 13671.0, 14373.0, 15075.0, 15777.0, + 13572.0, 14349.0, 15126.0, 15903.0, 16680.0, 17457.0, + 5616.0, 5931.0, 6246.0, 6561.0, 6876.0, 7191.0, + 16182.0, 17109.0, 18036.0, 18963.0, 19890.0, 20817.0, + 17487.0, 18489.0, 19491.0, 20493.0, 21495.0, 22497.0, + 18792.0, 19869.0, 20946.0, 22023.0, 23100.0, 24177.0, + 7650.0, 8079.0, 8508.0, 8937.0, 9366.0, 9795.0, + 4963.5, 5227.5, 5491.5, 5755.5, 6019.5, 6283.5, + 5328.0, 5611.5, 5895.0, 6178.5, 6462.0, 6745.5, + 5692.5, 5995.5, 6298.5, 6601.5, 6904.5, 7207.5, + 1757.25, 1840.5, 1923.75, 2007.0, 2090.25, 2173.5] + + self._VerifyValues(tensor_in_sizes=[1, 4, 4, 2], + depthwise_filter_in_sizes=[2, 2, 2, 3], + pointwise_filter_in_sizes=[1, 1, 6, 6], + stride=1, padding="SAME", + expected=expected_output) + + def testSeparableConv2DIllegalCases(self): + # Output depth less then input depth. + with self.assertRaisesRegexp( + ValueError, + "Refusing to perform an overparameterized separable convolution"): + self._VerifyValues(tensor_in_sizes=[1, 4, 4, 2], + depthwise_filter_in_sizes=[2, 2, 2, 3], + pointwise_filter_in_sizes=[1, 1, 6, 5], + stride=1, padding="SAME", + expected=None) + def GetInceptionFwdTest(input_size, filter_size, stride, padding): def Test(self): diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py index e9e66e0617d..3bd6fa2c97b 100644 --- a/tensorflow/python/kernel_tests/depthtospace_op_test.py +++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py @@ -180,6 +180,10 @@ class DepthToSpaceTest(tf.test.TestCase): with self.assertRaises(IndexError): _ = tf.space_to_depth(x_np, block_size) + def testUnknownShape(self): + t = tf.depth_to_space(tf.placeholder(tf.float32), block_size=4) + self.assertEqual(4, t.get_shape().ndims) + class DepthToSpaceGradientTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/spacetobatch_op_test.py b/tensorflow/python/kernel_tests/spacetobatch_op_test.py index 2cb6f5e1048..690c9371cc0 100644 --- a/tensorflow/python/kernel_tests/spacetobatch_op_test.py +++ b/tensorflow/python/kernel_tests/spacetobatch_op_test.py @@ -195,6 +195,11 @@ class SpaceToBatchErrorHandlingTest(tf.test.TestCase): with self.assertRaises(IndexError): _ = tf.space_to_batch(x_np, paddings, block_size) + def testUnknownShape(self): + t = tf.space_to_batch(tf.placeholder(tf.float32), tf.placeholder(tf.int32), + block_size=4) + self.assertEqual(4, t.get_shape().ndims) + class SpaceToBatchGradientTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py index 88e650840e5..82612f6fbaa 100644 --- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py +++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py @@ -198,6 +198,10 @@ class SpaceToDepthTest(tf.test.TestCase): with self.assertRaises(IndexError): _ = tf.space_to_depth(x_np, block_size) + def testUnknownShape(self): + t = tf.space_to_depth(tf.placeholder(tf.float32), block_size=4) + self.assertEqual(4, t.get_shape().ndims) + class SpaceToDepthGradientTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/sparse_ops_test.py b/tensorflow/python/kernel_tests/sparse_ops_test.py index 037d1f2c3eb..f6474f58458 100644 --- a/tensorflow/python/kernel_tests/sparse_ops_test.py +++ b/tensorflow/python/kernel_tests/sparse_ops_test.py @@ -637,6 +637,17 @@ class SparseSoftmaxTest(test_util.TensorFlowTestCase): self.assertAllEqual(sp_t.indices.eval(), result.indices) self.assertAllEqual(shape, result.shape) + def testGradient(self): + x_shape = [2, 5, 10] + with self.test_session(use_gpu=False): + for dtype in [np.float32, np.float64]: + x_np = np.random.randn(*x_shape).astype(dtype) + x_tf, nnz = _sparsify(x_np) + y_tf = tf.sparse_softmax(x_tf) + err = tf.test.compute_gradient_error(x_tf.values, (nnz,), y_tf.values, + (nnz,)) + self.assertLess(err, 1e-4) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index c19d93829a7..b8aa311951a 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1723,15 +1723,19 @@ def _SpaceToBatchShape(op): raise ValueError("Attribute block_size has to be > 1.") paddings = tensor_util.constant_value(op.inputs[1]) - if (paddings[0, 0] < 0 or paddings[0, 1] < 0 or - paddings[1, 0] < 0 or paddings[1, 1] < 0): - raise ValueError("paddings cannot be negative.") + if paddings is not None: + if (paddings[0, 0] < 0 or paddings[0, 1] < 0 or + paddings[1, 0] < 0 or paddings[1, 1] < 0): + raise ValueError("paddings cannot be negative.") - input_height = input_shape[1] + paddings[0, 0] + paddings[0, 1] - input_width = input_shape[2] + paddings[1, 0] + paddings[1, 1] + input_height = input_shape[1] + paddings[0, 0] + paddings[0, 1] + input_width = input_shape[2] + paddings[1, 0] + paddings[1, 1] - if input_height % block_size > 0 or input_width % block_size > 0: - raise IndexError("block_size needs to divide both width and height.") + if input_height % block_size > 0 or input_width % block_size > 0: + raise IndexError("block_size needs to divide both width and height.") + else: + input_height = tensor_shape.Dimension(None) + input_width = tensor_shape.Dimension(None) batch = input_shape[0] * block_size * block_size height = input_height // block_size @@ -1792,8 +1796,9 @@ def _BatchToSpaceShape(op): "tf.space_to_batch() requires input crops with shape [2, 2].") crops = tensor_util.constant_value(op.inputs[1]) - if (crops[0, 0] < 0 or crops[0, 1] < 0 or - crops[1, 0] < 0 or crops[1, 1] < 0): + if (crops is not None and + (crops[0, 0] < 0 or crops[0, 1] < 0 or + crops[1, 0] < 0 or crops[1, 1] < 0)): raise ValueError("crops cannot be negative.") block_size = op.get_attr("block_size") @@ -1805,10 +1810,14 @@ def _BatchToSpaceShape(op): raise IndexError("input batch must be divisible by block_size*block_size.") batch = input_batch // (block_size * block_size) - height = input_shape[1] * block_size - crops[0, 0] - crops[0, 1] - width = input_shape[2] * block_size - crops[1, 0] - crops[1, 1] - if height <= 0 or width <= 0: - raise ValueError("Output height or width is not positive.") + if crops is not None: + height = input_shape[1] * block_size - crops[0, 0] - crops[0, 1] + width = input_shape[2] * block_size - crops[1, 0] - crops[1, 1] + if height <= 0 or width <= 0: + raise ValueError("Output height or width is not positive.") + else: + height = tensor_shape.Dimension(None) + width = tensor_shape.Dimension(None) depth = input_shape[3] return [tensor_shape.TensorShape([batch, height, width, depth])] diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py index 92911009eb7..36c82278584 100644 --- a/tensorflow/python/ops/linalg_grad.py +++ b/tensorflow/python/ops/linalg_grad.py @@ -32,6 +32,9 @@ from tensorflow.python.ops import constant_op from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +ops.NoGradient("CholeskyGrad") +ops.NoGradient("BatchCholeskyGrad") + @ops.RegisterGradient("MatrixInverse") def _MatrixInverseGrad(op, grad): @@ -76,11 +79,17 @@ def _BatchMatrixDeterminantGrad(op, grad): @ops.RegisterGradient("Cholesky") -def _cholesky_grad(op, grad): +def _CholeskyGrad(op, grad): """Gradient for Cholesky.""" return linalg_ops.cholesky_grad(op.outputs[0], grad) +@ops.RegisterGradient("BatchCholesky") +def _BatchCholeskyGrad(op, grad): + """Gradient for BatchCholesky.""" + return linalg_ops.batch_cholesky_grad(op.outputs[0], grad) + + @ops.RegisterGradient("MatrixSolve") def _MatrixSolveGrad(op, grad): """Gradients for MatrixSolve.""" diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py index 983851e09e4..66a64a998ba 100644 --- a/tensorflow/python/ops/linalg_ops.py +++ b/tensorflow/python/ops/linalg_ops.py @@ -28,6 +28,7 @@ from tensorflow.python.ops.gen_linalg_ops import * @ops.RegisterShape("Cholesky") +@ops.RegisterShape("CholeskyGrad") @ops.RegisterShape("MatrixInverse") def _UnchangedSquare(op): input_shape = op.inputs[0].get_shape().with_rank(2) @@ -37,6 +38,7 @@ def _UnchangedSquare(op): @ops.RegisterShape("BatchCholesky") +@ops.RegisterShape("BatchCholeskyGrad") @ops.RegisterShape("BatchMatrixInverse") def _BatchUnchangedSquare(op): input_shape = op.inputs[0].get_shape().with_rank_at_least(2) @@ -44,10 +46,6 @@ def _BatchUnchangedSquare(op): input_shape[-1].assert_is_compatible_with(input_shape[-2]) return [input_shape] -@ops.RegisterShape("CholeskyGrad") -def _cholesky_grad_shape(op): - return [op.inputs[0].get_shape()] - @ops.RegisterShape("MatrixDeterminant") def _MatrixDeterminantShape(op): input_shape = op.inputs[0].get_shape().with_rank(2) diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index d47b03db5b6..e6b8bb664fd 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -559,6 +559,10 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides, Returns: A 4-D `Tensor` of shape `[batch, out_height, out_width, out_channels]`. + + Raises: + ValueError: If channel_multiplier * in_channels > out_channels, + which means that the separable convolution is overparameterized. """ with ops.op_scope([input, depthwise_filter, pointwise_filter], name, "separable_conv2d") as name: @@ -576,8 +580,13 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides, channel_multiplier = depthwise_filter.get_shape()[3] in_channels = input.get_shape()[3] out_channels = pointwise_filter.get_shape()[3] - # This would mean the separable convolutions is over-parametrized. - assert channel_multiplier * in_channels < out_channels + if channel_multiplier * in_channels > out_channels: + raise ValueError( + ("Refusing to perform an overparameterized separable " + "convolution: channel_multiplier * in_channels = " + "%d * %d = %d > %d = out_channels" % + (channel_multiplier, in_channels, + channel_multiplier * in_channels, out_channels))) # The layout of the ops in the graph are expected to be as follows: # depthwise_conv2d // Conv2D op corresponding to native deptwise conv. # separable_conv2d // Conv2D op corresponding to the pointwise conv. diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py index 4e43073e9e6..f97e5f5b564 100644 --- a/tensorflow/python/ops/session_ops.py +++ b/tensorflow/python/ops/session_ops.py @@ -207,7 +207,7 @@ def _get_handle_reader(graph, handle, dtype): if result is None: # Create reader if we haven't done it. handle_device = TensorHandle._get_device_name(handle) - with ops.device(handle_device): + with graph.as_default(), graph.device(handle_device): holder = array_ops.placeholder(dtypes.string) _register_handle_feeder(holder.graph, holder, dtype) reader = gen_data_flow_ops._get_session_tensor(holder, dtype) @@ -234,7 +234,7 @@ def _get_handle_mover(graph, feeder, handle): if result is None: # Create mover if we haven't done it. holder, reader = _get_handle_reader(graph, handle, dtype) - with ops.device(feeder.op.device): + with graph.as_default(), graph.device(feeder.op.device): mover = gen_data_flow_ops._get_session_handle(reader) result = (holder, mover) graph._handle_movers[graph_key] = result @@ -248,7 +248,7 @@ def _get_handle_deleter(graph, handle): if result is None: # Create deleter if we haven't done it. handle_device = TensorHandle._get_device_name(handle) - with ops.device(handle_device): + with graph.as_default(), graph.device(handle_device): holder = array_ops.placeholder(dtypes.string) deleter = gen_data_flow_ops._delete_session_tensor(holder) result = (holder, deleter) diff --git a/tensorflow/python/ops/sparse_grad.py b/tensorflow/python/ops/sparse_grad.py index d2e1ae20967..16c45aba544 100644 --- a/tensorflow/python/ops/sparse_grad.py +++ b/tensorflow/python/ops/sparse_grad.py @@ -227,6 +227,33 @@ def _SparseDenseCwiseDivGrad(op, grad): @ops.RegisterGradient("SparseSoftmax") -def _SparseSoftmaxGrad(unused_op, unused_grad): - raise NotImplementedError("SparseSoftmax op doesn't have its gradient" - "implemented yet") +def _SparseSoftmaxGrad(op, grad): + """Gradients for SparseSoftmax. + + The calculation is the same as SoftmaxGrad: + + grad_x = grad_softmax * softmax - sum(grad_softmax * softmax) * softmax + + where we now only operate on the non-zero values present in the SparseTensors. + + Args: + op: the SparseSoftmax op. + grad: the upstream gradient w.r.t. the non-zero SparseSoftmax output values. + + Returns: + Gradients w.r.t. the input (sp_indices, sp_values, sp_shape). + """ + indices, shape = op.inputs[0], op.inputs[2] + out_vals = op.outputs[0] + sp_output = ops.SparseTensor(indices, out_vals, shape) + sp_grad = ops.SparseTensor(indices, grad, shape) + sp_product = ops.SparseTensor( + indices, sp_output.values * sp_grad.values, shape) + + # [..., B, 1], dense. + sum_reduced = -sparse_ops.sparse_reduce_sum(sp_product, [-1], keep_dims=True) + # sparse [..., B, C] + dense [..., B, 1] with broadcast; outputs sparse. + sp_sum = sparse_ops.sparse_dense_cwise_add(sp_grad, sum_reduced) + + grad_x = sp_sum.values * sp_output.values + return [None, grad_x, None]