Merge pull request #2503 from caisq/123201123

Push changes from internal: 123201123
This commit is contained in:
Vijay Vasudevan 2016-05-25 10:34:29 -07:00
commit 0050a205bc
14 changed files with 308 additions and 97 deletions

View File

@ -162,7 +162,7 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
def _run_metrics(self, predictions, targets, metrics, weights):
result = {}
targets = math_ops.cast(targets, predictions.dtype)
for name, metric in six.iteritems(metrics):
for name, metric in six.iteritems(metrics or {}):
if "weights" in inspect.getargspec(metric)[0]:
result[name] = metric(predictions, targets, weights=weights)
else:
@ -211,7 +211,8 @@ class _DNNLinearCombinedBaseEstimator(estimator.BaseEstimator):
def _get_feature_ops_from_example(self, examples_batch):
column_types = layers.create_dict_for_parse_example(
self._get_linear_feature_columns() + self._get_dnn_feature_columns())
(self._get_linear_feature_columns() or []) +
(self._get_dnn_feature_columns() or []))
features = parsing_ops.parse_example(examples_batch, column_types)
return features

View File

@ -1,66 +0,0 @@
"""Generic trainer for TensorFlow models.
This module is deprecated, please use graph_actions.
"""
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python.platform import tf_logging as logging
def train(session,
train_op,
loss,
global_step,
feed_dict_fn,
steps,
monitor,
summary_writer=None,
summaries=None,
feed_params_fn=None):
"""Trains a model for given number of steps, given feed_dict function.
Args:
session: Session object.
train_op: Tensor, trains model.
loss: Tensor, loss value.
global_step: Tensor, global step of the model.
feed_dict_fn: Function that will return a feed dictionary.
steps: Number of steps to run.
monitor: Monitor object to track training progress and induce early
stopping
summary_writer: SummaryWriter object to use for writing summaries.
summaries: Joined object of all summaries that should be ran.
feed_params_fn: Feed params function.
"""
logging.warning("learn.trainer.train is deprecated. "
"Please use learn.graph_actions.train instead.")
for step in xrange(steps):
feed_dict = feed_dict_fn()
if summaries is not None:
global_step_value, loss_value, summ, _ = session.run(
[global_step, loss, summaries, train_op],
feed_dict=feed_dict)
else:
global_step_value, loss_value, _ = session.run(
[global_step, loss, train_op],
feed_dict=feed_dict)
if summaries is not None and summary_writer and summ is not None:
summary_writer.add_summary(summ, global_step_value)

View File

@ -131,13 +131,13 @@ class ResizeBilinearOpGrad : public OpKernel {
const float inverse_x_lerp = (1.0f - x_lerp);
for (int64 c = 0; c < st.channels; ++c) {
output_grad(b, top_y_index, left_x_index, c) +=
input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp;
T(input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp);
output_grad(b, top_y_index, right_x_index, c) +=
input_grad(b, y, x, c) * inverse_y_lerp * x_lerp;
T(input_grad(b, y, x, c) * inverse_y_lerp * x_lerp);
output_grad(b, bottom_y_index, left_x_index, c) +=
input_grad(b, y, x, c) * y_lerp * inverse_x_lerp;
T(input_grad(b, y, x, c) * y_lerp * inverse_x_lerp);
output_grad(b, bottom_y_index, right_x_index, c) +=
input_grad(b, y, x, c) * y_lerp * x_lerp;
T(input_grad(b, y, x, c) * y_lerp * x_lerp);
}
}
}
@ -165,6 +165,9 @@ REGISTER_KERNEL_BUILDER(Name("ResizeBilinearGrad")
ResizeBilinearOpGrad<CPUDevice, float>);
REGISTER_KERNEL_BUILDER(Name("ResizeBilinearGrad")
.Device(DEVICE_CPU)
.TypeConstraint<double>("T"),
ResizeBilinearOpGrad<CPUDevice, double>);
.TypeConstraint<Eigen::half>("T"),
ResizeBilinearOpGrad<CPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
Name("ResizeBilinearGrad").Device(DEVICE_CPU).TypeConstraint<double>("T"),
ResizeBilinearOpGrad<CPUDevice, double>);
} // namespace tensorflow

View File

@ -15273,6 +15273,44 @@ op {
}
}
}
op {
name: "ResizeArea"
input_arg {
name: "images"
type_attr: "T"
}
input_arg {
name: "size"
type: DT_INT32
}
output_arg {
name: "resized_images"
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "align_corners"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ResizeBicubic"
input_arg {
@ -15368,6 +15406,44 @@ op {
}
}
}
op {
name: "ResizeBicubic"
input_arg {
name: "images"
type_attr: "T"
}
input_arg {
name: "size"
type: DT_INT32
}
output_arg {
name: "resized_images"
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "align_corners"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ResizeBilinear"
input_arg {
@ -15463,6 +15539,44 @@ op {
}
}
}
op {
name: "ResizeBilinear"
input_arg {
name: "images"
type_attr: "T"
}
input_arg {
name: "size"
type: DT_INT32
}
output_arg {
name: "resized_images"
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "align_corners"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ResizeBilinearGrad"
input_arg {
@ -15520,6 +15634,39 @@ op {
}
}
}
op {
name: "ResizeBilinearGrad"
input_arg {
name: "grads"
type: DT_FLOAT
}
input_arg {
name: "original_image"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
type: DT_DOUBLE
}
}
}
attr {
name: "align_corners"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ResizeNearestNeighbor"
input_arg {
@ -15615,6 +15762,44 @@ op {
}
}
}
op {
name: "ResizeNearestNeighbor"
input_arg {
name: "images"
type_attr: "T"
}
input_arg {
name: "size"
type: DT_INT32
}
output_arg {
name: "resized_images"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "align_corners"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "ResizeNearestNeighborGrad"
input_arg {
@ -15678,6 +15863,42 @@ op {
}
}
}
op {
name: "ResizeNearestNeighborGrad"
input_arg {
name: "grads"
type_attr: "T"
}
input_arg {
name: "size"
type: DT_INT32
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
type: DT_INT32
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "align_corners"
type: "bool"
default_value {
b: false
}
}
}
op {
name: "Restore"
input_arg {

View File

@ -22,7 +22,7 @@ REGISTER_OP("ResizeArea")
.Input("images: T")
.Input("size: int32")
.Output("resized_images: float")
.Attr("T: {uint8, int8, int16, int32, int64, float, double}")
.Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
.Attr("align_corners: bool = false")
.Doc(R"doc(
Resize `images` to `size` using area interpolation.
@ -44,7 +44,7 @@ REGISTER_OP("ResizeBicubic")
.Input("images: T")
.Input("size: int32")
.Output("resized_images: float")
.Attr("T: {uint8, int8, int16, int32, int64, float, double}")
.Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
.Attr("align_corners: bool = false")
.Doc(R"doc(
Resize `images` to `size` using bicubic interpolation.
@ -66,7 +66,7 @@ REGISTER_OP("ResizeBilinear")
.Input("images: T")
.Input("size: int32")
.Output("resized_images: float")
.Attr("T: {uint8, int8, int16, int32, int64, float, double}")
.Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
.Attr("align_corners: bool = false")
.Doc(R"doc(
Resize `images` to `size` using bilinear interpolation.
@ -88,7 +88,7 @@ REGISTER_OP("ResizeBilinearGrad")
.Input("grads: float")
.Input("original_image: T")
.Output("output: T")
.Attr("T: {float, double}")
.Attr("T: {float, half, double}")
.Attr("align_corners: bool = false")
.Doc(R"doc(
Computes the gradient of bilinear interpolation.
@ -109,7 +109,7 @@ REGISTER_OP("ResizeNearestNeighbor")
.Input("images: T")
.Input("size: int32")
.Output("resized_images: T")
.Attr("T: {uint8, int8, int16, int32, int64, float, double}")
.Attr("T: {uint8, int8, int16, int32, int64, half, float, double}")
.Attr("align_corners: bool = false")
.Doc(R"doc(
Resize `images` to `size` using nearest neighbor interpolation.
@ -129,7 +129,7 @@ REGISTER_OP("ResizeNearestNeighborGrad")
.Input("grads: T")
.Input("size: int32")
.Output("output: T")
.Attr("T: {uint8, int8, int32, float, double}")
.Attr("T: {uint8, int8, int32, half, float, double}")
.Attr("align_corners: bool = false")
.Doc(R"doc(
Computes the gradient of nearest neighbor interpolation.

View File

@ -9004,6 +9004,7 @@ op {
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@ -9047,6 +9048,7 @@ op {
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@ -9090,6 +9092,7 @@ op {
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@ -9129,6 +9132,7 @@ op {
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
type: DT_DOUBLE
}
}
@ -9170,6 +9174,7 @@ op {
type: DT_INT16
type: DT_INT32
type: DT_INT64
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
@ -9210,6 +9215,7 @@ op {
type: DT_UINT8
type: DT_INT8
type: DT_INT32
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}

View File

@ -7,7 +7,7 @@ Input images can be of different types but output images are always float.
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `float32`, `float64`.
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.

View File

@ -7,7 +7,7 @@ Input images can be of different types but output images are always float.
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `float32`, `float64`.
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.

View File

@ -7,7 +7,7 @@ Input images can be of different types but output images are always float.
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `float32`, `float64`.
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.

View File

@ -5,7 +5,7 @@ Resize `images` to `size` using nearest neighbor interpolation.
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `float32`, `float64`.
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.

View File

@ -265,7 +265,7 @@ Input images can be of different types but output images are always float.
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `float32`, `float64`.
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
@ -292,7 +292,7 @@ Input images can be of different types but output images are always float.
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `float32`, `float64`.
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
@ -319,7 +319,7 @@ Input images can be of different types but output images are always float.
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `float32`, `float64`.
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.
@ -344,7 +344,7 @@ Resize `images` to `size` using nearest neighbor interpolation.
##### Args:
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `float32`, `float64`.
* <b>`images`</b>: A `Tensor`. Must be one of the following types: `uint8`, `int8`, `int16`, `int32`, `int64`, `half`, `float32`, `float64`.
4-D with shape `[batch, height, width, channels]`.
* <b>`size`</b>: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
new size for the images.

View File

@ -187,7 +187,7 @@ source "${VENV_DIR}/bin/activate" || \
# Force tensorflow reinstallation. Otherwise it may not get installed from
# last build if it had the same version number as previous build.
PIP_FLAGS="--upgrade --no-deps --force-reinstall"
PIP_FLAGS="--upgrade --force-reinstall"
pip install -v ${PIP_FLAGS} ${WHL_PATH} || \
die "pip install (forcing to reinstall tensorflow) FAILED"
echo "Successfully installed pip package ${WHL_PATH}"

View File

@ -39,12 +39,10 @@ apt-get install -y \
openjdk-8-jre-headless \
pkg-config \
python-dev \
python-numpy \
python-pandas \
python-pip \
python-virtualenv \
python3-dev \
python3-numpy \
python3-pandas \
python3-pip \
sudo \

View File

@ -16,13 +16,52 @@
set -e
# Use pip to install scipy to get the latest version, instead of 0.13 through
# apt-get
pip install scipy==0.15.1
pip3 install scipy==0.15.1
# Install pip packages from whl files to avoid the time-consuming process of
# building from source.
pip install sklearn
pip3 install scikit-learn
# Use pip to install numpy to the latest version, instead of 1.8.2 through
# apt-get
wget -q https://pypi.python.org/packages/06/92/3c786303889e6246971ad4c48ac2b4e37a1b1c67c0dc2106dc85cb15c18e/numpy-1.11.0-cp27-cp27mu-manylinux1_x86_64.whl#md5=6ffb66ff78c28c55bfa09a2ceee487df
mv numpy-1.11.0-cp27-cp27mu-manylinux1_x86_64.whl \
numpy-1.11.0-cp27-none-linux_x86_64.whl
pip install numpy-1.11.0-cp27-none-linux_x86_64.whl
rm numpy-1.11.0-cp27-none-linux_x86_64.whl
wget -q https://pypi.python.org/packages/ea/ca/5e48a68be496e6f79c3c8d90f7c03ea09bbb154ea4511f5b3d6c825cefe5/numpy-1.11.0-cp34-cp34m-manylinux1_x86_64.whl#md5=08a002aeffa20354aa5045eadb549361
mv numpy-1.11.0-cp34-cp34m-manylinux1_x86_64.whl \
numpy-1.11.0-cp34-cp34m-linux_x86_64.whl
pip3 install numpy-1.11.0-cp34-cp34m-linux_x86_64.whl
rm numpy-1.11.0-cp34-cp34m-linux_x86_64.whl
# Use pip to install scipy to get the latest version, instead of 0.13 through
# apt-get.
# pip install scipy==0.17.1
wget -q https://pypi.python.org/packages/8a/de/326cf31a5a3ba0c01c40cdd78f7140b0510ed80e6d5ec5b2ec173c72df03/scipy-0.17.1-cp27-cp27mu-manylinux1_x86_64.whl#md5=8d0df61ceba78a2796f8d90fc979576f
mv scipy-0.17.1-cp27-cp27mu-manylinux1_x86_64.whl \
scipy-0.17.1-cp27-none-linux_x86_64.whl
pip install scipy-0.17.1-cp27-none-linux_x86_64.whl
rm scipy-0.17.1-cp27-none-linux_x86_64.whl
# pip3 install scipy==0.17.1
wget -q https://pypi.python.org/packages/eb/2e/76aff3b25dd06cab06622f82a4790ff5002ab686e940847bb2503b4b2122/scipy-0.17.1-cp34-cp34m-manylinux1_x86_64.whl#md5=bb39b9e1d16fa220967ad7edd39a8b28
mv scipy-0.17.1-cp34-cp34m-manylinux1_x86_64.whl \
scipy-0.17.1-cp34-cp34m-linux_x86_64.whl
pip3 install scipy-0.17.1-cp34-cp34m-linux_x86_64.whl
rm scipy-0.17.1-cp34-cp34m-linux_x86_64.whl
# pip install sklearn
wget -q https://pypi.python.org/packages/bf/80/06e77e5a682c46a3880ec487a5f9d910f5c8d919df9aca58052089687c7e/scikit_learn-0.17.1-cp27-cp27mu-manylinux1_x86_64.whl#md5=337b91f502138ba7fd722803138f6dfd
mv scikit_learn-0.17.1-cp27-cp27mu-manylinux1_x86_64.whl \
scikit_learn-0.17.1-cp27-none-linux_x86_64.whl
pip install scikit_learn-0.17.1-cp27-none-linux_x86_64.whl
rm scikit_learn-0.17.1-cp27-none-linux_x86_64.whl
# pip3 install scikit-learn
wget -q https://pypi.python.org/packages/7e/f1/1cc8a1ae2b4de89bff0981aee904ff05779c49a4c660fa38178f9772d3a7/scikit_learn-0.17.1-cp34-cp34m-manylinux1_x86_64.whl#md5=a722a7372b64ec9f7b49a2532d21372b
mv scikit_learn-0.17.1-cp34-cp34m-manylinux1_x86_64.whl \
scikit_learn-0.17.1-cp34-cp34m-linux_x86_64.whl
pip3 install scikit_learn-0.17.1-cp34-cp34m-linux_x86_64.whl
rm scikit_learn-0.17.1-cp34-cp34m-linux_x86_64.whl
# Benchmark tests require the following:
pip install psutil
@ -33,3 +72,12 @@ pip3 install py-cpuinfo
# pylint tests require the following:
pip install pylint
pip3 install pylint
# Remove packages in /usr/lib/python* that may interfere with packages in
# /usr/local/lib. These packages may get installed inadvertantly with packages
# such as apt-get python-pandas. Their older versions can mask the more recent
# versions installed above with pip and cause test failures.
rm -rf /usr/lib/python2.7/dist-packages/numpy \
/usr/lib/python2.7/dist-packages/scipy \
/usr/lib/python3/dist-packages/numpy \
/usr/lib/python3/dist-packages/scipy