Merge changes from github.
Change: 127101926
This commit is contained in:
parent
cb8cdf73c3
commit
6cd8b28da1
1
AUTHORS
1
AUTHORS
@ -7,3 +7,4 @@
|
||||
# The email address is not required for organizations.
|
||||
|
||||
Google Inc.
|
||||
Yuan Tang terrytangyuan@gmail.com
|
||||
|
@ -59,6 +59,7 @@ Hello, TensorFlow!
|
||||
|
||||
* [TensorFlow website](http://tensorflow.org)
|
||||
* [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf)
|
||||
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
|
||||
* [TensorFlow MOOC on Udacity] (https://www.udacity.com/course/deep-learning--ud730)
|
||||
|
||||
The TensorFlow community has created amazing things with TensorFlow, please see the [resources section of tensorflow.org](https://www.tensorflow.org/versions/master/resources#community) for an incomplete list.
|
||||
|
@ -51,7 +51,7 @@
|
||||
|
||||
This release contains contributions from many people at Google, as well as:
|
||||
|
||||
Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan (Terry) Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson
|
||||
Aaron Schumacher, Aidan Dang, Akihiko ITOH, Aki Sukegawa, Arbit Chen, Aziz Alto, Danijar Hafner, Erik Erwitt, Fabrizio Milo, Felix Maximilian Möller, Henry Saputra, Sung Kim, Igor Babuschkin, Jan Zikes, Jeremy Barnes, Jesper Steen Møller, Johannes Mayer, Justin Harris, Kashif Rasul, Kevin Robinson, Loo Rong Jie, Lucas Moura, Łukasz Bieniasz-Krzywiec, Mario Cho, Maxim Grechkin, Michael Heilman, Mostafa Rahmani, Mourad Mourafiq, @ninotoshi, Orion Reblitz-Richardson, Yuncheng Li, @raoqiyu, Robert DiPietro, Sam Abrahams, Sebastian Raschka, Siddharth Agrawal, @snakecharmer1024, Stephen Roller, Sung Kim, SunYeop Lee, Thijs Vogels, Till Hoffmann, Victor Melo, Ville Kallioniemi, Waleed Abdulla, Wenjian Huang, Yaroslav Bulatov, Yeison Rodriguez, Yuan (Terry) Tang, Yuxin Wu, @zhongzyd, Ziming Dong, Zohar Jackson
|
||||
|
||||
We are also grateful to all who filed issues or helped resolve them, asked and
|
||||
answered questions, and were part of inspiring discussions.
|
||||
|
@ -59,8 +59,8 @@ Simple linear classification:
|
||||
from sklearn import datasets, metrics
|
||||
|
||||
iris = datasets.load_iris()
|
||||
classifier = learn.TensorFlowLinearClassifier(n_classes=3)
|
||||
classifier.fit(iris.data, iris.target)
|
||||
classifier = learn.LinearClassifier(n_classes=3)
|
||||
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
|
||||
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
|
||||
print("Accuracy: %f" % score)
|
||||
```
|
||||
@ -74,8 +74,8 @@ from sklearn import datasets, metrics, preprocessing
|
||||
|
||||
boston = datasets.load_boston()
|
||||
x = preprocessing.StandardScaler().fit_transform(boston.data)
|
||||
regressor = learn.TensorFlowLinearRegressor()
|
||||
regressor.fit(x, boston.target)
|
||||
regressor = learn.LinearRegressor()
|
||||
regressor.fit(x, boston.target, steps=200, batch_size=32)
|
||||
score = metrics.mean_squared_error(regressor.predict(x), boston.target)
|
||||
print ("MSE: %f" % score)
|
||||
```
|
||||
@ -88,15 +88,15 @@ Example of 3 layer network with 10, 20 and 10 hidden units respectively:
|
||||
from sklearn import datasets, metrics
|
||||
|
||||
iris = datasets.load_iris()
|
||||
classifier = learn.TensorFlowDNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
|
||||
classifier.fit(iris.data, iris.target)
|
||||
classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3)
|
||||
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
|
||||
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
|
||||
print("Accuracy: %f" % score)
|
||||
```
|
||||
|
||||
## Custom model
|
||||
|
||||
Example of how to pass a custom model to the TensorFlowEstimator:
|
||||
Example of how to pass a custom model to the Estimator:
|
||||
|
||||
```python
|
||||
from sklearn import datasets, metrics
|
||||
@ -108,7 +108,7 @@ def my_model(x, y):
|
||||
layers = learn.ops.dnn(x, [10, 20, 10], dropout=0.5)
|
||||
return learn.models.logistic_regression(layers, y)
|
||||
|
||||
classifier = learn.TensorFlowEstimator(model_fn=my_model, n_classes=3)
|
||||
classifier = learn.Estimator(model_fn=my_model, n_classes=3)
|
||||
classifier.fit(iris.data, iris.target)
|
||||
score = metrics.accuracy_score(iris.target, classifier.predict(iris.data))
|
||||
print("Accuracy: %f" % score)
|
||||
@ -116,16 +116,16 @@ print("Accuracy: %f" % score)
|
||||
|
||||
## Saving / Restoring models
|
||||
|
||||
Each estimator has a ``save`` method which takes folder path where all model information will be saved. For restoring you can just call ``learn.TensorFlowEstimator.restore(path)`` and it will return object of your class.
|
||||
Each estimator has a ``save`` method which takes folder path where all model information will be saved. For restoring you can just call ``learn.Estimator.restore(path)`` and it will return object of your class.
|
||||
|
||||
Some example code:
|
||||
|
||||
```python
|
||||
classifier = learn.TensorFlowLinearRegression()
|
||||
classifier = learn.LinearRegressor()
|
||||
classifier.fit(...)
|
||||
classifier.save('/tmp/tf_examples/my_model_1/')
|
||||
|
||||
new_classifier = TensorFlowEstimator.restore('/tmp/tf_examples/my_model_2')
|
||||
new_classifier = Estimator.restore('/tmp/tf_examples/my_model_2')
|
||||
new_classifier.predict(...)
|
||||
```
|
||||
|
||||
@ -134,7 +134,7 @@ new_classifier.predict(...)
|
||||
To get nice visualizations and summaries you can use ``logdir`` parameter on ``fit``. It will start writing summaries for ``loss`` and histograms for variables in your model. You can also add custom summaries in your custom model function by calling ``tf.summary`` and passing Tensors to report.
|
||||
|
||||
```python
|
||||
classifier = learn.TensorFlowLinearRegression()
|
||||
classifier = learn.LinearRegressor()
|
||||
classifier.fit(x, y, logdir='/tmp/tf_examples/my_model_1/')
|
||||
```
|
||||
|
||||
|
@ -152,6 +152,11 @@ class DNNClassifier(dnn_linear_combined.DNNLinearCombinedClassifier):
|
||||
gradient_clip_norm=gradient_clip_norm,
|
||||
enable_centered_bias=enable_centered_bias,
|
||||
config=config)
|
||||
self.feature_columns = feature_columns
|
||||
self.optimizer = optimizer
|
||||
self.activation_fn = activation_fn
|
||||
self.dropout = dropout
|
||||
self.hidden_units = hidden_units
|
||||
self._feature_columns_inferred = False
|
||||
|
||||
# TODO(b/29580537): Remove feature_columns inference.
|
||||
@ -299,6 +304,11 @@ class DNNRegressor(dnn_linear_combined.DNNLinearCombinedRegressor):
|
||||
gradient_clip_norm=gradient_clip_norm,
|
||||
enable_centered_bias=enable_centered_bias,
|
||||
config=config)
|
||||
self.feature_columns = feature_columns
|
||||
self.optimizer = optimizer
|
||||
self.activation_fn = activation_fn
|
||||
self.dropout = dropout
|
||||
self.hidden_units = hidden_units
|
||||
self._feature_columns_inferred = False
|
||||
|
||||
# TODO(b/29580537): Remove feature_columns inference.
|
||||
|
@ -21,6 +21,13 @@ from __future__ import print_function
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
from sklearn.cross_validation import cross_val_score
|
||||
HAS_SKLEARN = True
|
||||
except ImportError:
|
||||
HAS_SKLEARN = False
|
||||
|
||||
|
||||
def _iris_input_fn():
|
||||
iris = tf.contrib.learn.datasets.load_iris()
|
||||
@ -59,6 +66,28 @@ class DNNClassifierTest(tf.test.TestCase):
|
||||
classifier.fit(input_fn=_iris_input_fn, steps=1000)
|
||||
self.assertFalse('centered_bias_weight' in classifier.get_variable_names())
|
||||
|
||||
def testSklearnCompatibility(self):
|
||||
"""Tests compatibility with sklearn"""
|
||||
if not HAS_SKLEARN:
|
||||
return
|
||||
iris = tf.contrib.learn.datasets.load_iris()
|
||||
kwargs = {
|
||||
"n_classes": 3,
|
||||
"optimizer" : "Adam",
|
||||
"hidden_units" : [3, 4]
|
||||
}
|
||||
|
||||
classifier = tf.contrib.learn.DNNClassifier(**kwargs)
|
||||
|
||||
scores = cross_val_score(
|
||||
classifier,
|
||||
iris.data[1:5],
|
||||
iris.target[1:5],
|
||||
scoring="accuracy",
|
||||
fit_params={"steps": 2}
|
||||
)
|
||||
self.assertAllClose(scores, [1, 1, 1])
|
||||
|
||||
|
||||
class DNNRegressorTest(tf.test.TestCase):
|
||||
|
||||
|
@ -20,16 +20,24 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.dataframe import tensorflow_dataframe as df
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
try:
|
||||
import pandas as pd
|
||||
HAS_PANDAS = True
|
||||
except ImportError:
|
||||
HAS_PANDAS = False
|
||||
|
||||
|
||||
class SumTestCase(tf.test.TestCase):
|
||||
"""Test class for `Sum` transform."""
|
||||
|
||||
def testSum(self):
|
||||
if not HAS_PANDAS:
|
||||
return
|
||||
num_rows = 100
|
||||
|
||||
pandas_df = pd.DataFrame({"a": np.arange(num_rows),
|
||||
|
@ -61,7 +61,7 @@ On Ubuntu, you can do this:
|
||||
```bash
|
||||
sudo apt-get install autoconf automake libtool curl make g++ unzip
|
||||
pushd .
|
||||
cd tensforflow/contrib/makefile/downloads/protobuf
|
||||
cd tensorflow/contrib/makefile/downloads/protobuf
|
||||
./autogen.sh
|
||||
./configure
|
||||
make
|
||||
@ -104,7 +104,7 @@ tensorflow/contrib/makefile/gen/bin/benchmark \
|
||||
## Android
|
||||
|
||||
First, you will need to download and unzip the
|
||||
[Native Development Kit (NDK)](http://developers.google.com/ndk). You will not
|
||||
[Native Development Kit (NDK)](https://developer.android.com/ndk/). You will not
|
||||
need to install the standalone toolchain, however.
|
||||
|
||||
Assign your NDK location to $NDK_ROOT:
|
||||
@ -153,7 +153,7 @@ For more details, see the [benchmark documentation](../../tools/benchmark).
|
||||
## iOS
|
||||
|
||||
_Note: To use this library in an iOS application, see related instructions in
|
||||
the [iOS examples](../ios_examples/] directory._
|
||||
the [iOS examples](../ios_examples/) directory._
|
||||
|
||||
Install XCode 7.3 or more recent. If you have not already, you will need to
|
||||
install the command-line tools using `xcode-select`:
|
||||
@ -189,7 +189,7 @@ benchmark program. Although successfully compiling the benchmark program is a
|
||||
sign of success, the program is not a complete iOS app.
|
||||
|
||||
To see TensorFlow running on iOS, the example Xcode project in
|
||||
[tensorflow/contrib/ios_example](../ios_example) shows how to use the static
|
||||
[tensorflow/contrib/ios_examples](../ios_examples) shows how to use the static
|
||||
library in a simple app.
|
||||
|
||||
### Building by hand
|
||||
@ -227,7 +227,7 @@ benchmark program. Although successfully compiling the benchmark program is a
|
||||
sign of success, the program is not a complete iOS app.
|
||||
|
||||
To see TensorFlow running on iOS, the example Xcode project in
|
||||
[tensorflow/contrib/ios_example](../ios_example) shows how to use the static
|
||||
[tensorflow/contrib/ios_examples](../ios_examples) shows how to use the static
|
||||
library in a simple app.
|
||||
|
||||
#### Universal binaries
|
||||
|
@ -115,7 +115,7 @@ struct TF_Tensor {
|
||||
TensorBuffer* buffer;
|
||||
};
|
||||
|
||||
TF_Tensor* TF_NewTensor(TF_DataType dtype, tensorflow::int64* dims,
|
||||
TF_Tensor* TF_NewTensor(TF_DataType dtype, const tensorflow::int64* dims,
|
||||
int num_dims, void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len, void* arg),
|
||||
void* deallocator_arg) {
|
||||
|
@ -187,7 +187,7 @@ typedef struct TF_Tensor TF_Tensor;
|
||||
// (*deallocator)(data, len, deallocator_arg)
|
||||
// Clients must provide a custom deallocator function so they can pass in
|
||||
// memory managed by something like numpy.
|
||||
extern TF_Tensor* TF_NewTensor(TF_DataType, long long* dims, int num_dims,
|
||||
extern TF_Tensor* TF_NewTensor(TF_DataType, const long long* dims, int num_dims,
|
||||
void* data, size_t len,
|
||||
void (*deallocator)(void* data, size_t len,
|
||||
void* arg),
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -54,6 +54,6 @@ val_monitor = monitors.ValidationMonitor(X_val, y_val, every_n_steps=50)
|
||||
classifier = learn.TensorFlowEstimator(model_fn=conv_model, n_classes=10,
|
||||
steps=1000, learning_rate=0.05,
|
||||
batch_size=128)
|
||||
classifier.fit(X_train, y_train, val_monitor)
|
||||
classifier.fit(X_train, y_train, monitors=[val_monitor])
|
||||
score = metrics.accuracy_score(y_test, classifier.predict(X_test))
|
||||
print('Test Accuracy: {0:f}'.format(score))
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,6 +1,6 @@
|
||||
# encoding: utf-8
|
||||
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
#t Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,6 +1,6 @@
|
||||
# encoding: utf-8
|
||||
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,6 +1,6 @@
|
||||
# encoding: utf-8
|
||||
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -6,7 +6,7 @@ Course information can be found at https://www.udacity.com/course/deep-learning-
|
||||
Running the Docker container from the Google Cloud repository
|
||||
-------------------------------------------------------------
|
||||
|
||||
docker run -p 8888:8888 -it --rm b.gcr.io/tensorflow-udacity/assignments:0.5.0
|
||||
docker run -p 8888:8888 -it b.gcr.io/tensorflow-udacity/assignments:0.5.0
|
||||
|
||||
Accessing the Notebooks
|
||||
-----------------------
|
||||
@ -19,6 +19,21 @@ On mac, find the virtual machine's IP using:
|
||||
|
||||
Then go to: http://IP:8888 (likely http://192.168.99.100:8888)
|
||||
|
||||
Saving Your Progress
|
||||
--------------------
|
||||
|
||||
Because of the `--rm` flag above, stopping the docker container removes it, so any changes you've made will disappear. One way around this is to remove the `--rm` flag, and name the container for easy restarting:
|
||||
```sh
|
||||
# you only need to "run" the container the first time:
|
||||
docker run -p 8888:8888 -it --name tensorflow-udacity b.gcr.io/tensorflow-udacity/assignments:0.5.0
|
||||
# …do various things…
|
||||
# when you're done, control-C to kill jupyter and stop the container
|
||||
# when you're ready to do more things, you can now just "start" the container:
|
||||
docker start -ai tensorflow-udacity
|
||||
# …do more things…
|
||||
# …repeat…
|
||||
```
|
||||
|
||||
FAQ
|
||||
---
|
||||
|
||||
|
@ -63,7 +63,7 @@ Then, select the correct binary to install:
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
@ -73,14 +73,14 @@ $ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/mac/tensorflow-
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
|
||||
$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
@ -153,7 +153,7 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
@ -163,14 +163,14 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
@ -277,7 +277,7 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 2.7
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 2.7
|
||||
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp27-none-linux_x86_64.whl
|
||||
|
||||
@ -287,14 +287,14 @@ Now, install TensorFlow just as you would for a regular Pip installation. First
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.4
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.4
|
||||
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp34-cp34m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, CPU only, Python 3.5
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||
# Ubuntu/Linux 64-bit, GPU enabled, Python 3.5
|
||||
# Requires CUDA toolkit 7.5 and CuDNN v4. For other versions, see "Install from sources" below.
|
||||
(tensorflow)$ export TF_BINARY_URL=https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.9.0-cp35-cp35m-linux_x86_64.whl
|
||||
|
||||
|
@ -6,7 +6,7 @@ were the top priorities. Using floating point arithmetic was the easiest way to
|
||||
preserve accuracy, and GPUs were well-equipped to accelerate those calculations,
|
||||
so it's natural that not much attention was paid to other numerical formats.
|
||||
|
||||
These days, we actually have a lot of models being being deployed in commercial
|
||||
These days, we actually have a lot of models being deployed in commercial
|
||||
applications. The computation demands of training grow with the number of
|
||||
researchers, but the cycles needed for inference expand in proportion to users.
|
||||
That means pure inference efficiency has become a burning issue for a lot of
|
||||
|
@ -33,8 +33,9 @@ something amazing with TensorFlow, we'd like to hear about it!
|
||||
|
||||
The TensorFlow community has created many great projects around TensorFlow, including:
|
||||
|
||||
* [@jtoy's awesome "Awesome TensorFlow" list of awesome things](https://github.com/jtoy/awesome-tensorflow)
|
||||
* [TensorFlow tutorials](https://github.com/pkmital/tensorflow_tutorials)
|
||||
* [Scikit Flow - Simplified Interface for TensorFlow](https://github.com/tensorflow/skflow)
|
||||
* [Scikit Flow - Simplified Interface for TensorFlow](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/learn/python/learn)
|
||||
* [Caffe to TensorFlow model converter](https://github.com/ethereon/caffe-tensorflow)
|
||||
|
||||
### Development
|
||||
|
@ -232,7 +232,7 @@ print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
|
||||
|
||||
## Build a Multilayer Convolutional Network
|
||||
|
||||
Getting 91% accuracy on MNIST is bad. It's almost embarrassingly bad. In this
|
||||
Getting 92% accuracy on MNIST is bad. It's almost embarrassingly bad. In this
|
||||
section, we'll fix that, jumping from a very simple model to something
|
||||
moderately sophisticated: a small convolutional neural network. This will get us
|
||||
to around 99.2% accuracy -- not state of the art, but respectable.
|
||||
@ -243,7 +243,7 @@ To create this model, we're going to need to create a lot of weights and biases.
|
||||
One should generally initialize weights with a small amount of noise for
|
||||
symmetry breaking, and to prevent 0 gradients. Since we're using ReLU neurons,
|
||||
it is also good practice to initialize them with a slightly positive initial
|
||||
bias to avoid "dead neurons." Instead of doing this repeatedly while we build
|
||||
bias to avoid "dead neurons". Instead of doing this repeatedly while we build
|
||||
the model, let's create two handy functions to do it for us.
|
||||
|
||||
```python
|
||||
|
@ -205,7 +205,7 @@ class BaseSession(SessionInterface):
|
||||
|
||||
Use with the `with` keyword to specify that calls to
|
||||
[`Operation.run()`](../../api_docs/python/framework.md#Operation.run) or
|
||||
[`Tensor.run()`](../../api_docs/python/framework.md#Tensor.run) should be
|
||||
[`Tensor.eval()`](../../api_docs/python/framework.md#Tensor.eval) should be
|
||||
executed in this session.
|
||||
|
||||
```python
|
||||
|
@ -215,6 +215,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
|
||||
self._compareBothSparse(x, np.tanh, tf.tanh)
|
||||
self._compareBothSparse(y, np.sign, tf.sign)
|
||||
self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
|
||||
|
||||
def testFloatTanhEdge(self):
|
||||
x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
|
||||
@ -254,6 +255,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3)
|
||||
self._compareBothSparse(x, np.tanh, tf.tanh)
|
||||
self._compareBothSparse(x, np.sign, tf.sign)
|
||||
self._compareBothSparse(x, np.sign, tf.erf)
|
||||
|
||||
def testDoubleBasic(self):
|
||||
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
|
||||
@ -292,6 +294,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
|
||||
self._compareBothSparse(x, np.tanh, tf.tanh)
|
||||
self._compareBothSparse(y, np.sign, tf.sign)
|
||||
self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
|
||||
|
||||
def testHalfBasic(self):
|
||||
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
|
||||
@ -325,6 +328,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
|
||||
self._compareBothSparse(x, np.tanh, tf.tanh)
|
||||
self._compareBothSparse(y, np.sign, tf.sign)
|
||||
self._compareBothSparse(x, np.vectorize(math.erf), tf.erf, tol=1e-3)
|
||||
|
||||
def testInt32Basic(self):
|
||||
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
|
||||
|
@ -348,6 +348,25 @@ def sqrt(x, name=None):
|
||||
return gen_math_ops.sqrt(x, name=name)
|
||||
|
||||
|
||||
def erf(x, name=None):
|
||||
"""Computes the Gauss error function of `x` element-wise.
|
||||
|
||||
Args:
|
||||
x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`,
|
||||
`float32`, `float64`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
|
||||
"""
|
||||
with ops.op_scope([x], name, "Erf") as name:
|
||||
if isinstance(x, ops.SparseTensor):
|
||||
x_erf = gen_math_ops.erf(x.values, name=name)
|
||||
return ops.SparseTensor(indices=x.indices, values=x_erf, shape=x.shape)
|
||||
else:
|
||||
return gen_math_ops.erf(x, name=name)
|
||||
|
||||
|
||||
def complex_abs(x, name=None):
|
||||
r"""Computes the complex absolute value of a tensor.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user