diff --git a/AUTHORS b/AUTHORS index e3289a50bcc..a46ae7e616a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -7,3 +7,4 @@ # The email address is not required for organizations. Google Inc. +Yuan Tang terrytangyuan@gmail.com diff --git a/README.md b/README.md index e640f54774b..578b985d643 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,7 @@ Hello, TensorFlow! * [TensorFlow website](http://tensorflow.org) * [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf) +* [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity] (https://www.udacity.com/course/deep-learning--ud730) The TensorFlow community has created amazing things with TensorFlow, please see the [resources section of tensorflow.org](https://www.tensorflow.org/versions/master/resources#community) for an incomplete list. diff --git a/RELEASE.md b/RELEASE.md index 3843d543e93..60d77764c06 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -51,7 +51,7 @@ This release contains contributions from many people at Google, as well as: -Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan (Terry) Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson +Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jeremy Barnes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan (Terry) Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson We are also grateful to all who filed issues or helped resolve them, asked and answered questions, and were part of inspiring discussions. diff --git a/tensorflow/contrib/learn/python/learn/README.md b/tensorflow/contrib/learn/python/learn/README.md index f474eb4e541..2016f53a8a2 100644 --- a/tensorflow/contrib/learn/python/learn/README.md +++ b/tensorflow/contrib/learn/python/learn/README.md @@ -59,8 +59,8 @@ Simple linear classification: from sklearn import datasets, metrics iris = datasets.load_iris() -classifier = learn.TensorFlowLinearClassifier(n_classes=3) -classifier.fit(iris.data, iris.target) +classifier = learn.LinearClassifier(n_classes=3) +classifier.fit(iris.data, iris.target, steps=200, batch_size=32) score = metrics.accuracy_score(iris.target, classifier.predict(iris.data)) print("Accuracy: %f" % score) ``` @@ -74,8 +74,8 @@ from sklearn import datasets, metrics, preprocessing boston = datasets.load_boston() x = preprocessing.StandardScaler().fit_transform(boston.data) -regressor = learn.TensorFlowLinearRegressor() -regressor.fit(x, boston.target) +regressor = learn.LinearRegressor() +regressor.fit(x, boston.target, steps=200, batch_size=32) score = metrics.mean_squared_error(regressor.predict(x), boston.target) print ("MSE: %f" % score) ``` @@ -88,15 +88,15 @@ Example of 3 layer network with 10, 20 and 10 hidden units respectively: from sklearn import datasets, metrics iris = datasets.load_iris() -classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3) -classifier.fit(iris.data, iris.target) +classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3) +classifier.fit(iris.data, iris.target, steps=200, batch_size=32) score = metrics.accuracy_score(iris.target, classifier.predict(iris.data)) print("Accuracy: %f" % score) ``` ## Custom model -Example of how to pass a custom model to the TensorFlowEstimator: +Example of how to pass a custom model to the Estimator: ```python from sklearn import datasets, metrics @@ -108,7 +108,7 @@ def my_model(x, y): layers = learn.ops.dnn(x, [10, 20, 10], dropout=0.5) return learn.models.logistic_regression(layers, y) -classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3) +classifier = learn.Estimator(model_fn=my_model, n_classes=3) classifier.fit(iris.data, iris.target) score = metrics.accuracy_score(iris.target, classifier.predict(iris.data)) print("Accuracy: %f" % score) @@ -116,16 +116,16 @@ print("Accuracy: %f" % score) ## Saving / Restoring models -Each estimator has a ``save`` method which takes folder path where all model information will be saved. For restoring you can just call ``learn.TensorFlowEstimator.restore(path)`` and it will return object of your class. +Each estimator has a ``save`` method which takes folder path where all model information will be saved. For restoring you can just call ``learn.Estimator.restore(path)`` and it will return object of your class. Some example code: ```python -classifier = learn.TensorFlowLinearRegression() +classifier = learn.LinearRegressor() classifier.fit(...) classifier.save('/tmp/tf_examples/my_model_1/') -new_classifier = TensorFlowEstimator.restore('/tmp/tf_examples/my_model_2') +new_classifier = Estimator.restore('/tmp/tf_examples/my_model_2') new_classifier.predict(...) ``` @@ -134,7 +134,7 @@ new_classifier.predict(...) To get nice visualizations and summaries you can use ``logdir`` parameter on ``fit``. It will start writing summaries for ``loss`` and histograms for variables in your model. You can also add custom summaries in your custom model function by calling ``tf.summary`` and passing Tensors to report. ```python -classifier = learn.TensorFlowLinearRegression() +classifier = learn.LinearRegressor() classifier.fit(x, y, logdir='/tmp/tf_examples/my_model_1/') ``` diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index fdb598efc5a..63d103ed35c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -152,6 +152,11 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier): gradient_clip_norm=gradient_clip_norm, enable_centered_bias=enable_centered_bias, config=config) + self.feature_columns = feature_columns + self.optimizer = optimizer + self.activation_fn = activation_fn + self.dropout = dropout + self.hidden_units = hidden_units self._feature_columns_inferred = False # TODO(b/29580537): Remove feature_columns inference. @@ -299,6 +304,11 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor): gradient_clip_norm=gradient_clip_norm, enable_centered_bias=enable_centered_bias, config=config) + self.feature_columns = feature_columns + self.optimizer = optimizer + self.activation_fn = activation_fn + self.dropout = dropout + self.hidden_units = hidden_units self._feature_columns_inferred = False # TODO(b/29580537): Remove feature_columns inference. diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index ea09f717857..6304d06f555 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -21,6 +21,13 @@ from __future__ import print_function import tensorflow as tf +# pylint: disable=g-import-not-at-top +try: + from sklearn.cross_validation import cross_val_score + HAS_SKLEARN = True +except ImportError: + HAS_SKLEARN = False + def _iris_input_fn(): iris = tf.contrib.learn.datasets.load_iris() @@ -59,6 +66,28 @@ class DNNClassifierTest(tf.test.TestCase): classifier.fit(input_fn=_iris_input_fn, steps=1000) self.assertFalse('centered_bias_weight' in classifier.get_variable_names()) + def testSklearnCompatibility(self): + """Tests compatibility with sklearn""" + if not HAS_SKLEARN: + return + iris = tf.contrib.learn.datasets.load_iris() + kwargs = { + "n_classes": 3, + "optimizer" : "Adam", + "hidden_units" : [3, 4] + } + + classifier = tf.contrib.learn.DNNClassifier(**kwargs) + + scores = cross_val_score( + classifier, + iris.data[1:5], + iris.target[1:5], + scoring="accuracy", + fit_params={"steps": 2} + ) + self.assertAllClose(scores, [1, 1, 1]) + class DNNRegressorTest(tf.test.TestCase): diff --git a/tensorflow/contrib/learn/python/learn/tests/dataframe/arithmetic_transform_test.py b/tensorflow/contrib/learn/python/learn/tests/dataframe/arithmetic_transform_test.py index be8305e3da3..2266caeb2f0 100644 --- a/tensorflow/contrib/learn/python/learn/tests/dataframe/arithmetic_transform_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/dataframe/arithmetic_transform_test.py @@ -20,16 +20,24 @@ from __future__ import division from __future__ import print_function import numpy as np -import pandas as pd import tensorflow as tf from tensorflow.contrib.learn.python.learn.dataframe import tensorflow_dataframe as df +# pylint: disable=g-import-not-at-top +try: + import pandas as pd + HAS_PANDAS = True +except ImportError: + HAS_PANDAS = False + class SumTestCase(tf.test.TestCase): """Test class for `Sum` transform.""" def testSum(self): + if not HAS_PANDAS: + return num_rows = 100 pandas_df = pd.DataFrame({"a": np.arange(num_rows), diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index ebaacdfcd98..dff9373c103 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -61,7 +61,7 @@ On Ubuntu, you can do this: ```bash sudo apt-get install autoconf automake libtool curl make g++ unzip pushd . -cd tensforflow/contrib/makefile/downloads/protobuf +cd tensorflow/contrib/makefile/downloads/protobuf ./autogen.sh ./configure make @@ -104,7 +104,7 @@ tensorflow/contrib/makefile/gen/bin/benchmark \ ## Android First, you will need to download and unzip the -[Native Development Kit (NDK)](http://developers.google.com/ndk). You will not +[Native Development Kit (NDK)](https://developer.android.com/ndk/). You will not need to install the standalone toolchain, however. Assign your NDK location to $NDK_ROOT: @@ -153,7 +153,7 @@ For more details, see the [benchmark documentation](../../tools/benchmark). ## iOS _Note: To use this library in an iOS application, see related instructions in -the [iOS examples](../ios_examples/] directory._ +the [iOS examples](../ios_examples/) directory._ Install XCode 7.3 or more recent. If you have not already, you will need to install the command-line tools using `xcode-select`: @@ -189,7 +189,7 @@ benchmark program. Although successfully compiling the benchmark program is a sign of success, the program is not a complete iOS app. To see TensorFlow running on iOS, the example Xcode project in -[tensorflow/contrib/ios_example](../ios_example) shows how to use the static +[tensorflow/contrib/ios_examples](../ios_examples) shows how to use the static library in a simple app. ### Building by hand @@ -227,7 +227,7 @@ benchmark program. Although successfully compiling the benchmark program is a sign of success, the program is not a complete iOS app. To see TensorFlow running on iOS, the example Xcode project in -[tensorflow/contrib/ios_example](../ios_example) shows how to use the static +[tensorflow/contrib/ios_examples](../ios_examples) shows how to use the static library in a simple app. #### Universal binaries diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/core/client/tensor_c_api.cc index e8aee7d3b1b..dccf66f2aec 100644 --- a/tensorflow/core/client/tensor_c_api.cc +++ b/tensorflow/core/client/tensor_c_api.cc @@ -115,7 +115,7 @@ struct TF_Tensor { TensorBuffer* buffer; }; -TF_Tensor* TF_NewTensor(TF_DataType dtype, tensorflow::int64* dims, +TF_Tensor* TF_NewTensor(TF_DataType dtype, const tensorflow::int64* dims, int num_dims, void* data, size_t len, void (*deallocator)(void* data, size_t len, void* arg), void* deallocator_arg) { diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/core/public/tensor_c_api.h index 1de5a86503a..d01f3a14cc6 100644 --- a/tensorflow/core/public/tensor_c_api.h +++ b/tensorflow/core/public/tensor_c_api.h @@ -187,7 +187,7 @@ typedef struct TF_Tensor TF_Tensor; // (*deallocator)(data, len, deallocator_arg) // Clients must provide a custom deallocator function so they can pass in // memory managed by something like numpy. -extern TF_Tensor* TF_NewTensor(TF_DataType, long long* dims, int num_dims, +extern TF_Tensor* TF_NewTensor(TF_DataType, const long long* dims, int num_dims, void* data, size_t len, void (*deallocator)(void* data, size_t len, void* arg), diff --git a/tensorflow/examples/skflow/digits.py b/tensorflow/examples/skflow/digits.py index b3c684b7dfc..6c9aec52da1 100644 --- a/tensorflow/examples/skflow/digits.py +++ b/tensorflow/examples/skflow/digits.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -54,6 +54,6 @@ val_monitor = monitors.ValidationMonitor(X_val, y_val, every_n_steps=50) classifier = learn.TensorFlowEstimator(model_fn=conv_model, n_classes=10, steps=1000, learning_rate=0.05, batch_size=128) -classifier.fit(X_train, y_train, val_monitor) +classifier.fit(X_train, y_train, monitors=[val_monitor]) score = metrics.accuracy_score(y_test, classifier.predict(X_test)) print('Test Accuracy: {0:f}'.format(score)) diff --git a/tensorflow/examples/skflow/dnn_autoencoder_iris.py b/tensorflow/examples/skflow/dnn_autoencoder_iris.py index c4383ae6083..284bd9e58a5 100644 --- a/tensorflow/examples/skflow/dnn_autoencoder_iris.py +++ b/tensorflow/examples/skflow/dnn_autoencoder_iris.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/hdf5_classification.py b/tensorflow/examples/skflow/hdf5_classification.py index edcce6fe6f8..50e7d73b954 100644 --- a/tensorflow/examples/skflow/hdf5_classification.py +++ b/tensorflow/examples/skflow/hdf5_classification.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/iris_custom_model.py b/tensorflow/examples/skflow/iris_custom_model.py index afce504b744..8e2ab2ec882 100644 --- a/tensorflow/examples/skflow/iris_custom_model.py +++ b/tensorflow/examples/skflow/iris_custom_model.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/iris_run_config.py b/tensorflow/examples/skflow/iris_run_config.py index de9b44d460e..6ca563e9a32 100644 --- a/tensorflow/examples/skflow/iris_run_config.py +++ b/tensorflow/examples/skflow/iris_run_config.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/iris_val_based_early_stopping.py b/tensorflow/examples/skflow/iris_val_based_early_stopping.py index 70dd8053aa5..c80a0ccca1b 100644 --- a/tensorflow/examples/skflow/iris_val_based_early_stopping.py +++ b/tensorflow/examples/skflow/iris_val_based_early_stopping.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/iris_with_pipeline.py b/tensorflow/examples/skflow/iris_with_pipeline.py index ee5f9aed81b..c548387f388 100644 --- a/tensorflow/examples/skflow/iris_with_pipeline.py +++ b/tensorflow/examples/skflow/iris_with_pipeline.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/language_model.py b/tensorflow/examples/skflow/language_model.py index dcd65bf9f6e..7ee709fd912 100644 --- a/tensorflow/examples/skflow/language_model.py +++ b/tensorflow/examples/skflow/language_model.py @@ -1,6 +1,6 @@ # encoding: utf-8 -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/mnist_rnn.py b/tensorflow/examples/skflow/mnist_rnn.py index a6a594fad56..ddd6d7910f4 100644 --- a/tensorflow/examples/skflow/mnist_rnn.py +++ b/tensorflow/examples/skflow/mnist_rnn.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/mnist_weights.py b/tensorflow/examples/skflow/mnist_weights.py index 9ad019f9a4f..37d527c42cc 100644 --- a/tensorflow/examples/skflow/mnist_weights.py +++ b/tensorflow/examples/skflow/mnist_weights.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +#t Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/multioutput_regression.py b/tensorflow/examples/skflow/multioutput_regression.py index c0ddf1cf307..ef76a6ce270 100644 --- a/tensorflow/examples/skflow/multioutput_regression.py +++ b/tensorflow/examples/skflow/multioutput_regression.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/multiple_gpu.py b/tensorflow/examples/skflow/multiple_gpu.py index 1168184a38d..50e4b8252e9 100644 --- a/tensorflow/examples/skflow/multiple_gpu.py +++ b/tensorflow/examples/skflow/multiple_gpu.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/neural_translation.py b/tensorflow/examples/skflow/neural_translation.py index 7832767145c..ded54608ba2 100644 --- a/tensorflow/examples/skflow/neural_translation.py +++ b/tensorflow/examples/skflow/neural_translation.py @@ -1,6 +1,6 @@ # encoding: utf-8 -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/neural_translation_word.py b/tensorflow/examples/skflow/neural_translation_word.py index 90c73f0ba5c..185835c139a 100644 --- a/tensorflow/examples/skflow/neural_translation_word.py +++ b/tensorflow/examples/skflow/neural_translation_word.py @@ -1,6 +1,6 @@ # encoding: utf-8 -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/out_of_core_data_classification.py b/tensorflow/examples/skflow/out_of_core_data_classification.py index 5f612db3d79..5ed6033cc09 100644 --- a/tensorflow/examples/skflow/out_of_core_data_classification.py +++ b/tensorflow/examples/skflow/out_of_core_data_classification.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/text_classification.py b/tensorflow/examples/skflow/text_classification.py index fe19e273d35..3d34617016c 100644 --- a/tensorflow/examples/skflow/text_classification.py +++ b/tensorflow/examples/skflow/text_classification.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py index fef5a2d9b3e..afaa0bfff76 100644 --- a/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py +++ b/tensorflow/examples/skflow/text_classification_builtin_rnn_model.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/text_classification_character_cnn.py b/tensorflow/examples/skflow/text_classification_character_cnn.py index 998ed308078..be627f316e5 100644 --- a/tensorflow/examples/skflow/text_classification_character_cnn.py +++ b/tensorflow/examples/skflow/text_classification_character_cnn.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/text_classification_character_rnn.py b/tensorflow/examples/skflow/text_classification_character_rnn.py index a3de8aa42b6..864f678d4e4 100644 --- a/tensorflow/examples/skflow/text_classification_character_rnn.py +++ b/tensorflow/examples/skflow/text_classification_character_rnn.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/text_classification_cnn.py b/tensorflow/examples/skflow/text_classification_cnn.py index 0cbed33ef13..46238d2f037 100644 --- a/tensorflow/examples/skflow/text_classification_cnn.py +++ b/tensorflow/examples/skflow/text_classification_cnn.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/skflow/text_classification_save_restore.py b/tensorflow/examples/skflow/text_classification_save_restore.py index 9cabc322059..2b2831eb527 100644 --- a/tensorflow/examples/skflow/text_classification_save_restore.py +++ b/tensorflow/examples/skflow/text_classification_save_restore.py @@ -1,4 +1,4 @@ -# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tensorflow/examples/udacity/README.md b/tensorflow/examples/udacity/README.md index 98edc71e594..4743ab557bb 100644 --- a/tensorflow/examples/udacity/README.md +++ b/tensorflow/examples/udacity/README.md @@ -6,7 +6,7 @@ Course information can be found at https://www.udacity.com/course/deep-learning- Running the Docker container from the Google Cloud repository ------------------------------------------------------------- - docker run -p 8888:8888 -it --rm b.gcr.io/tensorflow-udacity/assignments:0.5.0 + docker run -p 8888:8888 -it b.gcr.io/tensorflow-udacity/assignments:0.5.0 Accessing the Notebooks ----------------------- @@ -19,6 +19,21 @@ On mac, find the virtual machine's IP using: Then go to: http://IP:8888 (likely http://192.168.99.100:8888) +Saving Your Progress +-------------------- + +Because of the `--rm` flag above, stopping the docker container removes it, so any changes you've made will disappear. One way around this is to remove the `--rm` flag, and name the container for easy restarting: +```sh +# you only need to "run" the container the first time: +docker run -p 8888:8888 -it --name tensorflow-udacity b.gcr.io/tensorflow-udacity/assignments:0.5.0 +# …do various things… +# when you're done, control-C to kill jupyter and stop the container +# when you're ready to do more things, you can now just "start" the container: +docker start -ai tensorflow-udacity +# …do more things… +# …repeat… +``` + FAQ --- diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index 923535144b8..e1cece4faa3 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -63,7 +63,7 @@ Then, select the correct binary to install: # Ubuntu/Linux 64-bit, CPU only, Python 2.7 $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7 +# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below. $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl @@ -73,14 +73,14 @@ $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/tensorflow- # Ubuntu/Linux 64-bit, CPU only, Python 3.4 $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4 +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below. $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl # Ubuntu/Linux 64-bit, CPU only, Python 3.5 $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5 +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below. $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl @@ -153,7 +153,7 @@ Now, install TensorFlow just as you would for a regular Pip installation. First # Ubuntu/Linux 64-bit, CPU only, Python 2.7 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7 +# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below. (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl @@ -163,14 +163,14 @@ Now, install TensorFlow just as you would for a regular Pip installation. First # Ubuntu/Linux 64-bit, CPU only, Python 3.4 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4 +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below. (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl # Ubuntu/Linux 64-bit, CPU only, Python 3.5 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5 +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below. (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl @@ -277,7 +277,7 @@ Now, install TensorFlow just as you would for a regular Pip installation. First # Ubuntu/Linux 64-bit, CPU only, Python 2.7 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7 +# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below. (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl @@ -287,14 +287,14 @@ Now, install TensorFlow just as you would for a regular Pip installation. First # Ubuntu/Linux 64-bit, CPU only, Python 3.4 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4 +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below. (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl # Ubuntu/Linux 64-bit, CPU only, Python 3.5 (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl -# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5 +# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5 # Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below. (tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl diff --git a/tensorflow/g3doc/how_tos/quantization/index.md b/tensorflow/g3doc/how_tos/quantization/index.md index 0431a7ad615..61461822a0a 100644 --- a/tensorflow/g3doc/how_tos/quantization/index.md +++ b/tensorflow/g3doc/how_tos/quantization/index.md @@ -6,7 +6,7 @@ were the top priorities. Using floating point arithmetic was the easiest way to preserve accuracy, and GPUs were well-equipped to accelerate those calculations, so it's natural that not much attention was paid to other numerical formats. -These days, we actually have a lot of models being being deployed in commercial +These days, we actually have a lot of models being deployed in commercial applications. The computation demands of training grow with the number of researchers, but the cycles needed for inference expand in proportion to users. That means pure inference efficiency has become a burning issue for a lot of diff --git a/tensorflow/g3doc/resources/index.md b/tensorflow/g3doc/resources/index.md index 249ec503271..2c5d06946ca 100644 --- a/tensorflow/g3doc/resources/index.md +++ b/tensorflow/g3doc/resources/index.md @@ -33,8 +33,9 @@ something amazing with TensorFlow, we'd like to hear about it! The TensorFlow community has created many great projects around TensorFlow, including: +* [@jtoy's awesome "Awesome TensorFlow" list of awesome things](https://github.com/jtoy/awesome-tensorflow) * [TensorFlow tutorials](https://github.com/pkmital/tensorflow_tutorials) -* [Scikit Flow - Simplified Interface for TensorFlow](https://github.com/tensorflow/skflow) +* [Scikit Flow - Simplified Interface for TensorFlow](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/learn/python/learn) * [Caffe to TensorFlow model converter](https://github.com/ethereon/caffe-tensorflow) ### Development diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md index 12de1df66cb..324a29c02eb 100644 --- a/tensorflow/g3doc/tutorials/mnist/pros/index.md +++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md @@ -232,7 +232,7 @@ print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})) ## Build a Multilayer Convolutional Network -Getting 91% accuracy on MNIST is bad. It's almost embarrassingly bad. In this +Getting 92% accuracy on MNIST is bad. It's almost embarrassingly bad. In this section, we'll fix that, jumping from a very simple model to something moderately sophisticated: a small convolutional neural network. This will get us to around 99.2% accuracy -- not state of the art, but respectable. @@ -243,7 +243,7 @@ To create this model, we're going to need to create a lot of weights and biases. One should generally initialize weights with a small amount of noise for symmetry breaking, and to prevent 0 gradients. Since we're using ReLU neurons, it is also good practice to initialize them with a slightly positive initial -bias to avoid "dead neurons." Instead of doing this repeatedly while we build +bias to avoid "dead neurons". Instead of doing this repeatedly while we build the model, let's create two handy functions to do it for us. ```python diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 5f0549463f6..20658fa6327 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -205,7 +205,7 @@ class BaseSession(SessionInterface): Use with the `with` keyword to specify that calls to [`Operation.run()`](../../api_docs/python/framework.md#Operation.run) or - [`Tensor.run()`](../../api_docs/python/framework.md#Tensor.run) should be + [`Tensor.eval()`](../../api_docs/python/framework.md#Tensor.eval) should be executed in this session. ```python diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 7f1be574bbf..093da97469a 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -215,6 +215,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) + self._compareBothSparse(x, np.vectorize(math.erf), tf.erf) def testFloatTanhEdge(self): x = np.arange(40, 40 + 6).reshape(6).astype(np.float32) @@ -254,6 +255,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3) self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(x, np.sign, tf.sign) + self._compareBothSparse(x, np.sign, tf.erf) def testDoubleBasic(self): x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64) @@ -292,6 +294,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) + self._compareBothSparse(x, np.vectorize(math.erf), tf.erf) def testHalfBasic(self): x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16) @@ -325,6 +328,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) + self._compareBothSparse(x, np.vectorize(math.erf), tf.erf, tol=1e-3) def testInt32Basic(self): x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 0bcf45db76f..07d93160ad8 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -348,6 +348,25 @@ def sqrt(x, name=None): return gen_math_ops.sqrt(x, name=name) +def erf(x, name=None): + """Computes the Gauss error function of `x` element-wise. + + Args: + x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`, + `float32`, `float64`. + name: A name for the operation (optional). + + Returns: + A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. + """ + with ops.op_scope([x], name, "Erf") as name: + if isinstance(x, ops.SparseTensor): + x_erf = gen_math_ops.erf(x.values, name=name) + return ops.SparseTensor(indices=x.indices, values=x_erf, shape=x.shape) + else: + return gen_math_ops.erf(x, name=name) + + def complex_abs(x, name=None): r"""Computes the complex absolute value of a tensor.