diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 9bdb702f716..6976a372983 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -71,6 +71,7 @@ filegroup( name = "all_opensource_files", data = [ ":all_files", + "//tensorflow/c:all_files", "//tensorflow/cc:all_files", "//tensorflow/contrib:all_files", "//tensorflow/contrib/copy_graph:all_files", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD new file mode 100644 index 00000000000..fec8ca759ec --- /dev/null +++ b/tensorflow/c/BUILD @@ -0,0 +1,95 @@ +# Description: +# C API for TensorFlow, for use by client language bindings. + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_cuda_library", +) + +# For platform specific build config +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_kernel_tests_linkstatic", +) + +# ----------------------------------------------------------------------------- +# Public targets + +tf_cuda_library( + name = "c_api", + srcs = ["c_api.cc"], + hdrs = ["c_api.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cuda_library( + name = "tf_status_helper", + srcs = ["tf_status_helper.cc"], + hdrs = ["tf_status_helper.h"], + visibility = ["//visibility:public"], + deps = [ + ":c_api", + "//tensorflow/core:lib", + ], +) + +tf_cuda_library( + name = "checkpoint_reader", + srcs = ["checkpoint_reader.cc"], + hdrs = ["checkpoint_reader.h"], + visibility = ["//visibility:public"], + deps = [ + ":tf_status_helper", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +# ----------------------------------------------------------------------------- +# Tests + +tf_cc_test( + name = "c_api_test", + size = "small", + linkopts = select({ + "//tensorflow:darwin": ["-headerpad_max_install_names"], + "//conditions:default": [], + }), + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":c_api", + "//tensorflow/core:direct_session", + "//tensorflow/core:framework", + "//tensorflow/core:proto_text", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:math", + "//third_party/eigen3", + ], +) + +# ----------------------------------------------------------------------------- +# Google-internal targets. + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/core/client/tensor_c_api.cc b/tensorflow/c/c_api.cc similarity index 99% rename from tensorflow/core/client/tensor_c_api.cc rename to tensorflow/c/c_api.cc index 99e5d796817..54c33c9dffa 100644 --- a/tensorflow/core/client/tensor_c_api.cc +++ b/tensorflow/c/c_api.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/public/tensor_c_api.h" +#include "tensorflow/c/c_api.h" #include #include diff --git a/tensorflow/core/public/tensor_c_api.h b/tensorflow/c/c_api.h similarity index 99% rename from tensorflow/core/public/tensor_c_api.h rename to tensorflow/c/c_api.h index 9f4f7adde91..9d0b979bb94 100644 --- a/tensorflow/core/public/tensor_c_api.h +++ b/tensorflow/c/c_api.h @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// TODO(jeff,sanjay): Rename to tensorflow/public/c_api.h -#ifndef TENSORFLOW_PUBLIC_TENSOR_C_API_H_ -#define TENSORFLOW_PUBLIC_TENSOR_C_API_H_ +#ifndef TENSORFLOW_C_C_API_H_ +#define TENSORFLOW_C_C_API_H_ #include #include @@ -699,4 +698,4 @@ extern TF_Buffer TF_GetOpList(TF_Library* lib_handle); } /* end extern "C" */ #endif -#endif // TENSORFLOW_PUBLIC_TENSOR_C_API_H_ +#endif // TENSORFLOW_C_C_API_H_ diff --git a/tensorflow/core/client/tensor_c_api_test.cc b/tensorflow/c/c_api_test.cc similarity index 99% rename from tensorflow/core/client/tensor_c_api_test.cc rename to tensorflow/c/c_api_test.cc index 0bbc22495aa..23963caba7e 100644 --- a/tensorflow/core/client/tensor_c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/public/tensor_c_api.h" +#include "tensorflow/c/c_api.h" #include #include "tensorflow/core/framework/graph.pb_text.h" diff --git a/tensorflow/core/util/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc similarity index 97% rename from tensorflow/core/util/checkpoint_reader.cc rename to tensorflow/c/checkpoint_reader.cc index ba252ecc926..dd9cb225598 100644 --- a/tensorflow/core/util/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/util/checkpoint_reader.h" +#include "tensorflow/c/checkpoint_reader.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/util/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h similarity index 90% rename from tensorflow/core/util/checkpoint_reader.h rename to tensorflow/c/checkpoint_reader.h index 65d1949ef49..fb06d6d8640 100644 --- a/tensorflow/core/util/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H -#define TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H +#ifndef TENSORFLOW_C_CHECKPOINT_READER_H +#define TENSORFLOW_C_CHECKPOINT_READER_H +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/tensor_slice_reader.h" -#include "tensorflow/core/util/tf_status_helper.h" namespace tensorflow { @@ -60,4 +60,4 @@ class CheckpointReader { } // namespace checkpoint } // namespace tensorflow -#endif // TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H +#endif // TENSORFLOW_C_CHECKPOINT_READER_H diff --git a/tensorflow/core/util/tf_status_helper.cc b/tensorflow/c/tf_status_helper.cc similarity index 98% rename from tensorflow/core/util/tf_status_helper.cc rename to tensorflow/c/tf_status_helper.cc index d119b9845cf..747fd672f08 100644 --- a/tensorflow/core/util/tf_status_helper.cc +++ b/tensorflow/c/tf_status_helper.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/util/tf_status_helper.h" +#include "tensorflow/c/tf_status_helper.h" namespace tensorflow { diff --git a/tensorflow/core/util/tf_status_helper.h b/tensorflow/c/tf_status_helper.h similarity index 82% rename from tensorflow/core/util/tf_status_helper.h rename to tensorflow/c/tf_status_helper.h index b3cea3072c4..4bc56f9cb40 100644 --- a/tensorflow/core/util/tf_status_helper.h +++ b/tensorflow/c/tf_status_helper.h @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H -#define TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H +#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H +#define TENSORFLOW_C_TF_STATUS_HELPER_H +#include "tensorflow/c/c_api.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/public/tensor_c_api.h" namespace tensorflow { @@ -26,4 +26,4 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status); } // namespace tensorflow -#endif // TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H +#endif // TENSORFLOW_C_TF_STATUS_HELPER_H diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 626f4589ea5..3fd428e1220 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -263,6 +263,28 @@ cuda_py_tests( ], ) +cuda_py_tests( + name = "shape_test", + size = "small", + srcs = ["python/kernel_tests/shape_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_tests( + name = "bijector_test", + size = "small", + srcs = ["python/kernel_tests/bijector_test.py"], + additional_deps = [ + ":distributions_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py new file mode 100644 index 00000000000..fd2cf58fd29 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijector_test.py @@ -0,0 +1,67 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import tensorflow as tf + +from tensorflow.contrib.distributions.python.ops.bijector import _Exp # pylint: disable=line-too-long +from tensorflow.contrib.distributions.python.ops.bijector import _Identity # pylint: disable=line-too-long +from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long + + +class IdentityBijectorTest(tf.test.TestCase): + """Tests the correctness of the Y = g(X) = X transformation.""" + + def testBijector(self): + with self.test_session(): + bijector = _Identity(_ShapeUtil(batch_ndims=1, event_ndims=1)) + self.assertEqual(bijector.name, 'Identity') + x = [[[0.], [1]]] + self.assertAllEqual(bijector.forward(x).eval(), x) + self.assertAllEqual(bijector.inverse(x).eval(), x) + self.assertAllEqual(bijector.inverse_log_det_jacobian(x).eval(), + [[0., 0]]) + rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x) + self.assertAllEqual(rev.eval(), x) + self.assertAllEqual(jac.eval(), [[0., 0]]) + + +class ExpBijectorTest(tf.test.TestCase): + """Tests the correctness of the Y = g(X) = exp(X) transformation.""" + + def testBijector(self): + with self.test_session(): + bijector = _Exp(_ShapeUtil(batch_ndims=1, event_ndims=1)) + self.assertEqual(bijector.name, 'Exp') + x = [[[1.], [2]]] + self.assertAllClose(bijector.forward(x).eval(), + [[[math.exp(1.)], [math.exp(2.)]]]) + self.assertAllClose(bijector.inverse(x).eval(), + [[[math.log(1.)], [math.log(2.)]]]) + self.assertAllClose(bijector.inverse_log_det_jacobian(x).eval(), + [[0., -math.log(2.)]]) + rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x) + self.assertAllClose(rev.eval(), [[[math.log(1.)], [math.log(2.)]]]) + self.assertAllClose(jac.eval(), [[0., -math.log(2.)]]) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py new file mode 100644 index 00000000000..351c69c747f --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/shape_test.py @@ -0,0 +1,165 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for ShapeUtil.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long + + +class ShapeUtilTest(tf.test.TestCase): + + def testShapeUtilGetNdims(self): + with self.test_session(): + shaper = _ShapeUtil(batch_ndims=0, event_ndims=0) + x = 1 + self.assertEqual(shaper.get_sample_ndims(x), 0) + self.assertEqual(shaper.batch_ndims, 0) + self.assertEqual(shaper.event_ndims, 0) + + shaper = _ShapeUtil(batch_ndims=1, event_ndims=1) + x = [[[0., 1, 2], [3, 4, 5]]] + self.assertAllEqual(shaper.get_ndims(x), 3) + self.assertEqual(shaper.get_sample_ndims(x), 1) + self.assertEqual(shaper.batch_ndims, 1) + self.assertEqual(shaper.event_ndims, 1) + + x += [[[6, 7, 8], [9, 10, 11]]] + self.assertAllEqual(shaper.get_ndims(x), 3) + self.assertEqual(shaper.get_sample_ndims(x), 1) + self.assertEqual(shaper.batch_ndims, 1) + self.assertEqual(shaper.event_ndims, 1) + + # Test ndims functions work, even despite unfed Tensors. + y = tf.placeholder(tf.float32, shape=(1024, None, 1024)) + self.assertAllEqual(shaper.get_ndims(y), 3) + self.assertEqual(shaper.get_sample_ndims(y), 1) + self.assertEqual(shaper.batch_ndims, 1) + self.assertEqual(shaper.event_ndims, 1) + + with self.assertRaises(ValueError): + y = tf.placeholder(tf.float32) + shaper.get_ndims(y) + + def testShapeUtilGetDims(self): + with self.test_session(): + shaper = _ShapeUtil(batch_ndims=0, event_ndims=0) + with self.assertRaises(ValueError): + y = tf.placeholder(tf.float32) + shaper.get_sample_dims(y) + with self.assertRaises(ValueError): + y = tf.placeholder(tf.float32) + shaper.get_batch_dims(y) + with self.assertRaises(ValueError): + y = tf.placeholder(tf.float32) + shaper.get_event_dims(y) + + shaper = _ShapeUtil(batch_ndims=0, event_ndims=0) + x = 1 + self.assertAllEqual(shaper.get_sample_dims(x), []) + self.assertAllEqual(shaper.get_batch_dims(x), []) + self.assertAllEqual(shaper.get_event_dims(x), []) + self.assertAllEqual(shaper.get_dims(x, sample=False), []) + + shaper = _ShapeUtil(batch_ndims=1, event_ndims=2) + x = [[[[0., 1], [2, 4]]]] + self.assertAllEqual(shaper.get_sample_dims(x), [0]) + self.assertAllEqual(shaper.get_batch_dims(x), [1]) + self.assertAllEqual(shaper.get_event_dims(x), [2, 3]) + self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3]) + + x += x + self.assertAllEqual(shaper.get_sample_dims(x), [0]) + self.assertAllEqual(shaper.get_batch_dims(x), [1]) + self.assertAllEqual(shaper.get_event_dims(x), [2, 3]) + self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3]) + + # Test dims functions work, despite unfed Tensors. + y = tf.placeholder(tf.float32, shape=(1024, None, 5, 5)) + self.assertAllEqual(shaper.get_sample_dims(y), [0]) + self.assertAllEqual(shaper.get_batch_dims(y), [1]) + self.assertAllEqual(shaper.get_event_dims(y), [2, 3]) + + def testShapeUtilGetShape(self): + with self.test_session() as sess: + shaper = _ShapeUtil(batch_ndims=0, event_ndims=0) + with self.assertRaises(ValueError): + y = tf.placeholder(tf.float32) + shaper.get_sample_shape(y) + with self.assertRaises(ValueError): + y = tf.placeholder(tf.float32) + shaper.get_batch_shape(y) + with self.assertRaises(ValueError): + y = tf.placeholder(tf.float32) + shaper.get_event_shape(y) + + shaper = _ShapeUtil(batch_ndims=0, event_ndims=0) + x = 1 + self.assertAllEqual(shaper.get_sample_shape(x), []) + self.assertAllEqual(shaper.get_batch_shape(x), []) + self.assertAllEqual(shaper.get_event_shape(x), []) + self.assertAllEqual(shaper.get_shape(x, batch=False), []) + + shaper = _ShapeUtil(batch_ndims=1, event_ndims=1) + x = [[[0., 1, 2], [3, 4, 5]]] + self.assertAllEqual(shaper.get_sample_shape(x), [1]) + self.assertAllEqual(shaper.get_batch_shape(x), [2]) + self.assertAllEqual(shaper.get_event_shape(x), [3]) + self.assertAllEqual(shaper.get_shape(x, batch=False), [1, 3]) + + x += [[[6, 7, 8], [9, 10, 11]]] + self.assertAllEqual(shaper.get_sample_shape(x), [2]) + self.assertAllEqual(shaper.get_batch_shape(x), [2]) + self.assertAllEqual(shaper.get_event_shape(x), [3]) + self.assertAllEqual(shaper.get_shape(x, batch=False), [2, 3]) + + shaper = _ShapeUtil(batch_ndims=0, event_ndims=1) + x = tf.ones((3, 2)) + self.assertAllEqual(shaper.get_shape(x, sample=False), (2,)) + + def feed_eval(fun, build_shape=(None, None, 2), graph_shape=(3, 4, 2)): + """Helper to use a deferred-shape tensor eval'ed at graph runtime.""" + y = tf.placeholder(tf.int32, shape=build_shape) + y_value = np.ones(graph_shape, dtype=y.dtype.as_numpy_dtype()) + return sess.run(fun(y), + feed_dict={y: y_value}) + + shaper = _ShapeUtil(batch_ndims=1, event_ndims=1) + self.assertAllEqual(feed_eval(shaper.get_sample_shape), [3]) + self.assertAllEqual(feed_eval(shaper.get_batch_shape), [4]) + self.assertAllEqual(feed_eval(shaper.get_event_shape), [2]) + self.assertAllEqual( + feed_eval(lambda y: shaper.get_shape(y, batch=False)), + [3, 2]) + + shaper = _ShapeUtil(batch_ndims=0, event_ndims=1) + self.assertAllEqual( + feed_eval(lambda y: shaper.get_shape(y, batch=False), + (None, None), + (3, 2)), + [3, 2]) + self.assertAllEqual( + feed_eval(lambda y: shaper.get_shape(y, sample=False), + (None, None), + (3, 2)), + [2]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijector.py b/tensorflow/contrib/distributions/python/ops/bijector.py new file mode 100644 index 00000000000..ff54b6d386e --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijector.py @@ -0,0 +1,350 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""An API for reversible (bijective) transformations of random variables.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + + +class _Bijector(object): + """An interface for transforming random variable(s). + + A bijector is characterized by three operations: + + 1) Forward Evaluation + Useful for turning one random outcome into another random outcome from a + different distribution. + + 2) Inverse Evaluation + Useful for "reversing" a transformation to compute one probability in terms + of another. + + 3) (log o det o Jacobian o inverse)(x) + "The log of the determinant of the matrix of all first-order partial + derivatives of the inverse function." + Useful for inverting a transformation to compute one probability in terms + of another. Geometrically, the det(Jacobian) is the volume of the + transformation and is used to scale the probability. + + By convention, transformations of random variables are named in terms of the + forward transformation. The forward transformation creates samples, the + inverse is useful for computing probabilities. + + Example transformations: + "Exponential" + + ``` + Y = g(X) = exp(X) + X ~ Normal(0, 1) # Univariate. + ``` + + Implies: + + ``` + g^{-1}(Y) = log(Y) + |Jacobian(g^{-1})(y)| = 1 / y + Y ~ LogNormal(0, 1), i.e., + prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y)) + = (1 / y) Normal(log(y); 0, 1) + ``` + + "ShiftAndScale" + + ``` + Y = g(X) = sqrtSigma * X + mu + X ~ MultivariateNormal(0, I_d) + ``` + + Implies: + + ``` + g^{-1}(Y) = inv(sqrtSigma) * (Y - mu) + |Jacobian(g^{-1})(y)| = det(inv(sqrtSigma)) + Y ~ MultivariateNormal(mu, sqrtSigma) , i.e., + prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y)) + = det(sqrtSigma)^(-d) * + MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d) + ``` + + Example use: + Basic properties: + + ```python + x = ... # A tensor. + # Evaluate forward transformation. + fwd_x = my_bijector.forward(x) + x != my_bijector.forward(fwd_x) # Not equal because g(x) != g(g(x)). + x == my_bijector.inverse(fwd_x) + ``` + + Computing a log-likelihood: + + ```python + def transformed_log_pdf(bijector, log_pdf, x): + return (bijector.inverse_log_det_jacobian(x) + + log_pdf(bijector.inverse(x))) + ``` + + Transforming a random outcome: + + ```python + def transformed_sample(bijector, x): + return bijector.forward(x) + ``` + + """ + + # TODO(b/30476956): Try to remove constructor dependence on shape util. + def __init__(self, shaper=None, name=None): + """Constructs Bijector. + + A bijector transforms random variables into new random variables. Managing + shape is typically an important piece of this so a Bijector is usually + composed of ShapeUtil. The ShapeUtil object handles input shape checks as + well as reshaping/transposing for easier linear algebra operations. + + Example: + ```python + # Create the Y = g(X) = X transform which operates on 4-Tensors of vectors. + identity = Identity(ShapeUtil(batch_ndims=4, event_ndims=1)) + + # Create the Y = g(X) = exp(X) transform which operates on matrices. + exp = Exp(ShapeUtil(batch_ndims=0, event_ndims=2)) + ``` + + See Bijector subclass doc for more details and examples. + + Args: + shaper: object used for managing and manipulating shape, typically an + instance of ShapeUtil. + name: The name to give Ops created by the initializer. + """ + self._shaper = shaper + self._name = name or type(self).__name__ + + @property + def shaper(self): + """Returns shape object used to manage shape constraints.""" + return self._shaper + + @property + def name(self): + """Returns the string name of this bijector.""" + return self._name + + def forward(self, x, name='forward'): + """Returns the forward bijector evaluation, i.e., X = g(Y). + + Args: + x: `Tensor`. The input to the "forward" evaluation. + name: The name to give this op. + + Returns: + `Tensor`. + """ + with ops.name_scope(self.name): + with ops.op_scope([x], name): + x = ops.convert_to_tensor(x) + return self._forward(x) + + def inverse(self, x, name='inverse'): + """Returns the inverse bijector evaluation, i.e., X = g^{-1}(Y). + + Args: + x: `Tensor`. The input to the "inverse" evaluation. + name: The name to give this op. + + Returns: + `Tensor`. + """ + with ops.name_scope(self.name): + with ops.op_scope([x], name): + x = ops.convert_to_tensor(x) + try: + return self._inverse(x) + except NotImplementedError: + return self._inverse_and_inverse_log_det_jacobian(x)[0] + + def inverse_log_det_jacobian(self, x, name='inverse_log_det_jacobian'): + """Returns the (log o det o Jacobian o inverse)(x). + + Mathematically, returns: log(det(dY/dX g^{-1}))(Y). + + Args: + x: `Tensor`. The input to the "inverse" Jacobian evaluation. + name: The name to give this op. + + Returns: + `Tensor`. + """ + with ops.name_scope(self.name): + with ops.op_scope([x], name): + x = ops.convert_to_tensor(x) + try: + return self._inverse_log_det_jacobian(x) + except NotImplementedError: + return self._inverse_and_inverse_log_det_jacobian(x)[1] + + def inverse_and_inverse_log_det_jacobian( + self, x, name='inverse_and_inverse_log_det_jacobian'): + """Returns both the inverse evaluation and inverse_log_det_jacobian. + + Enables possibly more efficient calculation when both inverse and + corresponding Jacobian are needed. + + See `inverse()`, `inverse_log_det_jacobian()` for more details. + + Args: + x: `Tensor`. The input to the "inverse" Jacobian evaluation. + name: The name to give this op. + + Returns: + `Tensor`. + """ + with ops.name_scope(self.name): + with ops.op_scope([x], name): + x = ops.convert_to_tensor(x) + try: + return self._inverse_and_inverse_log_det_jacobian(x) + except NotImplementedError: + return self._inverse(x), self._inverse_log_det_jacobian(x) + + # Subclass interface. + def _forward(self, x): + """Subclass implementation of forward(). + + Args: + x: `Tensor`. The input to the "forward" evaluation. + + Raises: + `NotImplementedError`: if subclass implementation not provided + + Returns: + `Tensor`. + """ + raise NotImplementedError('_forward not implemented') + + def _inverse(self, x): + """Subclass implementation of inverse(). + + Args: + x: `Tensor`. The input to the "inverse" evaluation. + + Raises: + `NotImplementedError`: if subclass implementation not provided + + Returns: + `Tensor`. + """ + raise NotImplementedError('_inverse not implemented') + + def _inverse_log_det_jacobian(self, x): + """Subclass implementation of inverse_log_det_jacobian(). + + Args: + x: `Tensor`. The input to the "inverse" Jacobian evaluation. + + Raises: + `NotImplementedError`: if subclass implementation not provided + + Returns: + `Tensor`. + """ + raise NotImplementedError('_inverse_log_det_jacobian not implemented') + + def _inverse_and_inverse_log_det_jacobian(self, x): + """Subclass implementation of inverse_and_inverse_log_det_jacobian(). + + Args: + x: `Tensor`. The input to the "inverse" evaluation. + + Returns: + List of two `Tensor` items, inverse and inverse_log_det_jacobian. + """ + raise NotImplementedError( + '_inverse_and_inverse_log_det_jacobian not implemented') + + +class _Identity(_Bijector): + """Bijector which computes Y = g(X) = X. + + Example Use: + ```python + # Create the Y=g(X)=X transform which works only on Tensors with 1 batch + # ndims and 1 event ndim (i.e., vector of vectors). + identity = Identity(ShapeUtil(batch_ndims=1, event_ndims=1)) + x = [[1., 2], + [3, 4]] + x == identity.forward(x) == identity.inverse(x) + ``` + + """ + + # TODO(b/30476956): Try to remove constructor dependence on shape util. + def __init__(self, shaper=None, name='Identity'): + super(_Identity, self).__init__(shaper, name) + + def _forward(self, x): + return x + + def _inverse(self, x): + return x + + def _inverse_log_det_jacobian(self, x): + result_shape = self.shaper.get_shape( + x, sample=True, batch=True, event=False) + return array_ops.zeros(result_shape, dtype=x.dtype) + + +class _Exp(_Bijector): + """Bijector which computes Y = g(X) = exp(X). + + Example Use: + ```python + # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 + # batch ndims and 2 event ndim (i.e., vector of matrices). + exp = Exp(ShapeUtil(batch_ndims=1, event_ndims=2)) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + exp(x) == exp.forward(x) + log(x) == exp.inverse(x) + ``` + + """ + + # TODO(b/30476956): Try to remove constructor dependence on shape util. + def __init__(self, shaper=None, name='Exp'): + super(_Exp, self).__init__(shaper, name) + + def _forward(self, x): + return math_ops.exp(x) + + def _inverse(self, x): + return math_ops.log(x) + + def _inverse_log_det_jacobian(self, x): + d = self.shaper.get_event_dims(x) + return -math_ops.reduce_sum(math_ops.log(x), d) + + def _inverse_and_inverse_log_det_jacobian(self, x): + y = math_ops.log(x) + d = self.shaper.get_event_dims(x) + return y, -math_ops.reduce_sum(y, d) diff --git a/tensorflow/contrib/distributions/python/ops/shape.py b/tensorflow/contrib/distributions/python/ops/shape.py new file mode 100644 index 00000000000..2030ce22cd5 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/shape.py @@ -0,0 +1,396 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A helper class for inferring Distribution shape.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops + + +class _ShapeUtil(object): + """Class which helps infer/identify subsets of tensor dimensions. + + Terminology: + Recall that a `Tensor` has: + shape: sizes of tensor dimensions, + ndims: size of shape; number of tensor dimensions, + dims: indexes into shape; useful for transpose, reduce. + + Tensors sampled from a `Distribution` can be partitioned by: + sample dims: indexes independent, identically distributed (iid) draws, + batch dims: indexes non-identical draws, + event dims: indexes coordinates of a single draw. + + The sample, batch, and event dimensions constitute the entirety of a + `Tensor` shape. The dimensions are always in sample, batch, event order. + + Assumptions: + We assume that batch_ndims and event_ndims are statically known for both + creating this object and for inputs to its functions. + TODO(jvdillon): Relax this assumption and support fully unknown shape. + + We also assume that the `Tensor` rank is static, i.e., `x.get_shape().ndims + is not None`. + + Possible use-cases: + ~ Sample dimensions: + Computing summary statistics, i.e., the average is a reduction over sample + dimensions. + + ~ Batch dimensions: + Log-likelihood under model predicted location: + ```python + mu = ... # vector of predictions, one for each covariate. + neg_log_likelihood = -tf.reduce_mean( + Normal(loc=mu, scale=1).log_pdf(x), + reduce_dims=[0]) + ``` + + Monte Carlo estimation of a marginal probability: + Average over batch dimensions where batch dimensions are associated with + random draws of a prior. + E.g., suppose we want to find the Monte Carlo estimate of the marginal + distribution of a Normal with a random Laplace location: + ``` + P(X=x) = integral P(X=x|y) P(Y=y) dy + ~= 1/n sum_{i=1}^n P(X=x|y_i), y_i ~iid Laplace(0,1) + = tf.reduce_mean(Normal(loc=Laplace(0, 1).sample_n(n=1000), + scale=tf.ones([1000, 1])).pdf(x), + reduce_dims=[0]) + ``` + + The `Laplace` distribution generates a tensor of shape [1000, 1]. When fed + to a `Normal`, this is interpreted as 1000 different locations, i.e., + 1000 non-identical Normals. Therefore a single call to pdf(x) yields 1000 + probabilities, one for every location. The average over this batch yields + the marginal. + + ~ Event dimensions: + Computing the determinant of the Jacobian of a function of a random + variable involves a reduction over event dimensions. + + Examples: + Write S, B, E for sample shape, batch shape, and event shape (resp.). + + ```python + x.get_shape() == S + B + E # For statically known x shape. + + # 100 iid samples from one multivariate Normal with two + # degrees of freedom (DF). + mu = [0., 0] + sigma = [[1., 0], + [0, 1]] + X = MultivariateNormal(loc=mu, scale=sigma).sample_n(n=100) + # S = [100] + # B = [] + # E = [2] + + # 100 iid samples from one Wishart with 2x2 DF. + sigma = [[1., 0], + [0, 1]] + X = Wishart(scale=sigma).sample_n(n=100) + # S = [100] + # B = [] + # E = [2, 2] + + # 100 iid samples (with shape [2, 50]) from two, non-identical bivariate + # Normal distributions. + mu = ... # shape(2, 2) + sigma = ... # shape(2, 2, 2) + X = MultivariateNormal(loc=mu, scale=sigma).sample(shape=[2, 50]) + # S = [2, 50] + # B = [2] + # E = [2] + ``` + + """ + + def __init__(self, batch_ndims=None, event_ndims=None, name='ShapeUtil'): + """Construct ShapeUtil with known sample, batch, and/or event ndims. + + Typically, batch_ndims and event_ndims are fixed throughout the lifetime of + a Distribution. + + Args: + batch_ndims: number of dims (rank) of the batch portion of indexes of a + `Tensor`. A "batch" is a non-identical distribution, i.e, Normal with + different parameters. + event_ndims: number of dims (rank) of the event portion of indexes of a + `Tensor`. An "event" is what is sampled from a distribution, i.e., a + trivariate Normal has an event shape of [3] and a 4 dimensional Wishart + has an event shape of [4, 4]. + name: `String`. The name to give Ops created by this class. + + Raises: + ValueError: if batch_ndims or event_ndims are invalid. + """ + if batch_ndims < 0: + raise ValueError('must specify non-negative batch_ndims(%d)', batch_ndims) + if batch_ndims > 0 and event_ndims < 1: + raise ValueError('must specify positive event_ndims(%d) when ' + 'batch_ndims(%d) is positive', event_ndims, batch_ndims) + # TODO(jvdillon): Support batches of scalars. + self._name = name + self._batch_ndims = batch_ndims + self._event_ndims = event_ndims + + @property + def name(self): + """Name given to ops created by this class.""" + return self._name + + @property + def batch_ndims(self): + """Returns number of dimensions corresponding to non-identical draws.""" + return self._batch_ndims + + @property + def event_ndims(self): + """Returns number of dimensions needed to index a sample's coordinates.""" + return self._event_ndims + + def get_ndims(self, x, name='get_ndims'): + """Get tensor ndims (rank). + + Args: + x: `Tensor`. + name: `String`. The name to give this op. + + Raises: + ValueError: if ndims is not statically known. + + Returns: + `Scalar` number of dimensions associated with a `Tensor`. + """ + if x is None: + raise ValueError('Input was None which does not have known ndims.') + with ops.name_scope(self.name): + with ops.op_scope([x], name): + ndims = ops.convert_to_tensor(x).get_shape().ndims + if ndims is None: + raise ValueError('ShapeUtil assumes static number of ' + 'dimensions(%d)', ndims) + return ndims + + def get_sample_ndims(self, x): + """Returns number of dimensions corresponding to iid draws. + + Args: + x: `Tensor`. + + Raises: + ValueError: if batch_ndims or event_ndims are not statically known. + ValueError: if static sample_ndims does not match inferred + + Returns: + Scalar number of dimensions associated with a sample. + """ + ndims = self.get_ndims(x) + sample_ndims = ndims - self.batch_ndims - self.event_ndims + if sample_ndims < 0: + raise ValueError('expected batch_ndims(%d) + event_ndims(%d) < ndims(%d)', + self.batch_ndims, self.event_ndims, ndims) + return sample_ndims + + def get_dims(self, x, sample=True, batch=True, event=True): + """Returns subset of tensor's dimension indexes (indexes into shape). + + Args: + x: `Tensor`. + sample: `Boolean`. Include sample dimensions or not. + batch: `Boolean`. Include batch dimensions or not. + event: `Boolean`. Include event dimensions or not. + + Raises: + ValueError: if `x.get_shape().ndims` is `None` + + Returns: + List enumerating requested dimensions. + """ + ndims = self.get_ndims(x) + + if sample and batch and event: + return list(range(ndims)) + + sample_start = 0 + batch_start = self.get_sample_ndims(x) + event_start = batch_start + self.batch_ndims + + sample_shape = list(range(sample_start, batch_start)) if sample else [] + batch_shape = list(range(batch_start, event_start)) if batch else [] + event_shape = list(range(event_start, ndims)) if event else [] + + return sample_shape + batch_shape + event_shape + + def get_shape(self, x, sample=True, batch=True, event=True, name='get_shape'): + """Returns subset of tensor's shape (size of dimensions). + + Args: + x: `Tensor`. + sample: `Boolean`. Include sample shape or not. + batch: `Boolean`. Include batch shape or not. + event: `Boolean`. Include event shape or not. + name: `String`. The name to give this op. + + Raises: + ValueError: if `x.get_shape().ndims` is `None` + + Returns: + List describing event shape if known statically, `Tensor` otherwise. + """ + if not sample and not batch and not event: + return [] + with ops.name_scope(self._name): + with ops.op_scope([x], name): + x = ops.convert_to_tensor(x) + shape = (x.get_shape().as_list() + if x.get_shape().is_fully_defined() + else array_ops.shape(x)) + + if sample and batch and event: + return shape + + sample_start = 0 + batch_start = self.get_sample_ndims(x) + event_start = batch_start + self.batch_ndims + + sample_shape = shape[sample_start:batch_start] if sample else [] + batch_shape = shape[batch_start:event_start] if batch else [] + event_shape = shape[event_start:] if event else [] + + if not batch and not event: + return sample_shape + if not sample and not event: + return batch_shape + if not sample and not batch: + return event_shape + + if x.get_shape().is_fully_defined(): + return sample_shape + batch_shape + event_shape + else: + return array_ops.concat(0, [sample_shape, batch_shape, event_shape]) + + def get_sample_dims(self, x): + """Returns dimension indexes corresponding to sample. + + Convenience function; identical to: + + ```python + get_dims(x, sample=True, batch=False, event=False) + ``` + + Args: + x: `Tensor`. + + Raises: + ValueError: if `x.get_shape().ndims` is `None` + + Returns: + List enumerating sample dimensions. + """ + return self.get_dims(x, sample=True, batch=False, event=False) + + def get_batch_dims(self, x): + """Returns dimension indexes corresponding to batch. + + Convenience function; identical to: + + ```python + get_dims(x, sample=False, batch=True, event=False) + ``` + + Args: + x: `Tensor`. + + Raises: + ValueError: if `x.get_shape().ndims` is `None` + + Returns: + List enumerating batch dimensions. + """ + return self.get_dims(x, sample=False, batch=True, event=False) + + def get_event_dims(self, x): + """Returns dimension indexes corresponding to event. + + Convenience function; identical to: + + ```python + get_dims(x, sample=False, batch=False, event=True) + ``` + + Args: + x: `Tensor`. + + Raises: + ValueError: if `x.get_shape().ndims` is `None` + + Returns: + List enumerating event dimensions. + """ + return self.get_dims(x, sample=False, batch=False, event=True) + + def get_sample_shape(self, x): + """Returns shape corresponding to sample. + + Convenience function; identical to: + + ```python + get_shape(x, sample=True, batch=False, event=False) + ``` + + Args: + x: `Tensor`. + + Returns: + List describing sample shape if known statically, `Tensor` otherwise. + """ + return self.get_shape(x, sample=True, batch=False, event=False) + + def get_batch_shape(self, x): + """Returns shape corresponding to batch. + + Convenience function; identical to: + + ```python + get_shape(x, sample=False, batch=True, event=False) + ``` + + Args: + x: `Tensor`. + + Returns: + List describing batch shape if known statically, `Tensor` otherwise. + """ + return self.get_shape(x, sample=False, batch=True, event=False) + + def get_event_shape(self, x): + """Returns shape corresponding to event. + + Convenience function; identical to: + + ```python + get_shape(x, sample=False, batch=False, event=True) + ``` + + Args: + x: `Tensor`. + + Returns: + List describing event shape if known statically, `Tensor` otherwise. + """ + return self.get_shape(x, sample=False, batch=False, event=True) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 9011b01ce61..9dd59319848 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -41,6 +41,7 @@ py_test( size = "small", srcs = ["python/framework/checkpoint_utils_test.py"], srcs_version = "PY2AND3", + tags = ["manual"], # http://b/30468735 deps = ["//tensorflow:tensorflow_py"], ) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 53e42c03138..1f63b8a0151 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1036,6 +1036,7 @@ def separable_convolution2d( depthwise_weights = variables.model_variable( 'depthwise_weights', shape=depthwise_shape, + dtype=dtype, initializer=weights_initializer, regularizer=weights_regularizer, trainable=trainable, @@ -1048,6 +1049,7 @@ def separable_convolution2d( pointwise_weights = variables.model_variable( 'pointwise_weights', shape=pointwise_shape, + dtype=dtype, initializer=weights_initializer, regularizer=weights_regularizer, trainable=trainable, diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 02d38220def..9c9cfe4c99b 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1604,10 +1604,28 @@ class RepeatTests(tf.test.TestCase): class SeparableConv2dTest(tf.test.TestCase): - def testCreateConv(self): + def testCreateConvInt32(self): height, width = 3, 3 with self.test_session(): - images = tf.random_uniform((5, height, width, 3), seed=1) + images = tf.random_uniform( + (5, height, width, 3), seed=1, dtype=tf.int32, maxval=12345) + with self.assertRaisesRegexp(TypeError, 'non-floating point type'): + tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2) + + def testCreateConvFloat32(self): + height, width = 3, 3 + with self.test_session(): + images = tf.random_uniform( + (5, height, width, 3), seed=1, dtype=tf.float32) + output = tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2) + self.assertEquals(output.op.name, 'SeparableConv2d/Relu') + self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32]) + + def testCreateConvFloat64(self): + height, width = 3, 3 + with self.test_session(): + images = tf.random_uniform( + (5, height, width, 3), seed=1, dtype=tf.float64) output = tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2) self.assertEquals(output.op.name, 'SeparableConv2d/Relu') self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32]) diff --git a/tensorflow/contrib/learn/python/learn/supervised_session.py b/tensorflow/contrib/learn/python/learn/supervised_session.py index 768af3cc91f..982f49dd70a 100644 --- a/tensorflow/contrib/learn/python/learn/supervised_session.py +++ b/tensorflow/contrib/learn/python/learn/supervised_session.py @@ -123,8 +123,8 @@ class Scaffold(object): global_step_tensor = contrib_variables.get_or_create_global_step() self.global_step_tensor = global_step_tensor if init_op is None: - init_op = Scaffold._get_or_default( - ops.GraphKeys.INIT_OP, variables.initialize_all_variables) + init_op = Scaffold._get_or_default('init_op', ops.GraphKeys.INIT_OP, + variables.initialize_all_variables) self.init_op = init_op self.init_feed_dict = init_feed_dict # NOTE(touts): modifying the init function to be passed the scaffold is a @@ -135,19 +135,23 @@ class Scaffold(object): self.init_fn = None if ready_op is None: ready_op = Scaffold._get_or_default( - ops.GraphKeys.READY_OP, variables.report_uninitialized_variables) + 'ready_op', ops.GraphKeys.READY_OP, + variables.report_uninitialized_variables) self.ready_op = ready_op if local_init_op is None: - local_init_op = Scaffold._get_or_default( - ops.GraphKeys.LOCAL_INIT_OP, Scaffold._default_local_init_op) + local_init_op = Scaffold._get_or_default('local_init_op', + ops.GraphKeys.LOCAL_INIT_OP, + Scaffold._default_local_init_op) self.local_init_op = local_init_op if summary_op is None: - summary_op = Scaffold._get_or_default( - ops.GraphKeys.SUMMARY_OP, logging_ops.merge_all_summaries) + summary_op = Scaffold._get_or_default('summary_op', + ops.GraphKeys.SUMMARY_OP, + logging_ops.merge_all_summaries) self.summary_op = summary_op # pylint: disable=g-long-lambda if saver is None: saver = Scaffold._get_or_default( + 'saver', ops.GraphKeys.SAVERS, lambda: training_saver.Saver(sharded=True, max_to_keep=keep_checkpoint_max)) @@ -157,9 +161,16 @@ class Scaffold(object): ops.get_default_graph().finalize() @staticmethod - def _get_or_default(collection_key, default_constructor): + def _get_or_default(arg_name, collection_key, default_constructor): + """Get from cache or create a default operation.""" elements = ops.get_collection(collection_key) if elements: + if len(elements) > 1: + raise RuntimeError('More than one item in the collection "%s". ' + 'Please indicate which one to use by passing it to ' + 'the tf.Scaffold constructor as: ' + 'tf.Scaffold(%s=item to use)', collection_key, + arg_name) return elements[0] op = default_constructor() if op is not None: diff --git a/tensorflow/contrib/learn/python/learn/tests/supervised_session_test.py b/tensorflow/contrib/learn/python/learn/tests/supervised_session_test.py index ee03a565074..203878010d7 100644 --- a/tensorflow/contrib/learn/python/learn/tests/supervised_session_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/supervised_session_test.py @@ -57,6 +57,14 @@ class ScaffoldTest(tf.test.TestCase): self.assertEqual(scaffold1.local_init_op, scaffold2.local_init_op) self.assertEqual(scaffold1.saver, scaffold2.saver) + def test_raise_error_if_more_than_one_cached_item(self): + with tf.Graph().as_default(): + tf.Variable([1]) + tf.add_to_collection(tf.GraphKeys.SAVERS, tf.train.Saver()) + tf.add_to_collection(tf.GraphKeys.SAVERS, tf.train.Saver()) + with self.assertRaisesRegexp(RuntimeError, 'More than one item'): + supervised_session.Scaffold() + def test_uses_passed_values(self): with tf.Graph().as_default(): scaffold = supervised_session.Scaffold(global_step_tensor=1, diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index bad3e07d72e..4987e9bcd40 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -393,10 +393,8 @@ $(wildcard tensorflow/core/graph/dot.*) \ $(wildcard tensorflow/core/lib/gif/*) \ $(wildcard tensorflow/core/lib/jpeg/*) \ $(wildcard tensorflow/core/lib/png/*) \ -$(wildcard tensorflow/core/util/checkpoint_reader.*) \ $(wildcard tensorflow/core/util/events_writer.*) \ $(wildcard tensorflow/core/util/reporter.*) \ -$(wildcard tensorflow/core/util/tf_status_helper.*) \ $(wildcard tensorflow/core/platform/default/stream_executor.*) \ $(wildcard tensorflow/core/platform/default/test_benchmark.*) \ $(wildcard tensorflow/core/platform/cuda.h) \ diff --git a/tensorflow/contrib/quantization/ops/array_ops.cc b/tensorflow/contrib/quantization/ops/array_ops.cc index 35d0e7f4c9e..e1cf3ded93f 100644 --- a/tensorflow/contrib/quantization/ops/array_ops.cc +++ b/tensorflow/contrib/quantization/ops/array_ops.cc @@ -13,11 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { -// -------------------------------------------------------------------------- +using shape_inference::InferenceContext; +using shape_inference::Shape; REGISTER_OP("QuantizeV2") .Input("input: float") @@ -28,6 +31,15 @@ REGISTER_OP("QuantizeV2") .Output("output_max: float") .Attr("T: quantizedtype") .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Quantize the 'input' tensor of type float to 'output' tensor of type 'T'. @@ -96,6 +108,13 @@ REGISTER_OP("Dequantize") .Output("output: float") .Attr("T: quantizedtype") .Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + return Status::OK(); + }) .Doc(R"doc( Dequantize the 'input' tensor into a float Tensor. diff --git a/tensorflow/contrib/quantization/ops/nn_ops.cc b/tensorflow/contrib/quantization/ops/nn_ops.cc index 814011e4111..18db2b0eaa2 100644 --- a/tensorflow/contrib/quantization/ops/nn_ops.cc +++ b/tensorflow/contrib/quantization/ops/nn_ops.cc @@ -156,6 +156,15 @@ REGISTER_OP("QuantizedMaxPool") .Attr("ksize: list(int)") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); + const Shape* unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }) .Doc(R"doc( Produces the max pool of the input tensor for quantized types. diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md index 85bf786a66a..79b1a203ca3 100644 --- a/tensorflow/contrib/slim/README.md +++ b/tensorflow/contrib/slim/README.md @@ -198,7 +198,7 @@ of the corresponding TF-Slim code: ```python input = ... -net = slim.conv2d(input, [3, 3], 128, scope='conv1_1') +net = slim.conv2d(input, 128, [3, 3], scope='conv1_1') ``` TF-Slim provides standard implementations for numerous components for building @@ -431,7 +431,7 @@ between the predicted and true values. Certain models, such as multi-task learning models, require the use of multiple loss functions simultaneously. In -other words, the loss function ultimatey being minimized is the sum of various +other words, the loss function ultimately being minimized is the sum of various other loss functions. For example, consider a model that predicts both the type of scene in an image as well as the depth from the camera of each pixel. This model's loss function would be the sum of the diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 60bc8e17fa8..da0e8ca4c95 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -156,6 +156,7 @@ cc_library( "lib/io/table_options.h", "lib/jpeg/jpeg_mem.h", "lib/monitoring/counter.h", + "lib/monitoring/metric_def.h", "lib/random/distribution_sampler.h", "lib/random/philox_random.h", "lib/random/simple_philox.h", # TODO(josh11b): make internal @@ -423,9 +424,6 @@ tf_cuda_library( "graph/validate.h", "public/session.h", "public/session_options.h", - "public/tensor_c_api.h", - "util/checkpoint_reader.h", - "util/tf_status_helper.h", ], visibility = ["//visibility:public"], deps = [ @@ -605,10 +603,8 @@ filegroup( "lib/jpeg/**/*", "lib/png/**/*", "lib/gif/**/*", - "util/checkpoint_reader.*", "util/events_writer.*", "util/reporter.*", - "util/tf_status_helper.*", "platform/default/stream_executor.*", "platform/default/test_benchmark.*", "platform/cuda.h", @@ -917,8 +913,6 @@ tf_cuda_library( "**/*test*", "**/*main.cc", "example/example_parser_configuration.*", - "util/tf_status_helper.*", - "util/checkpoint_reader.*", "util/reporter.h", "util/reporter.cc", "framework/fake_input.*", @@ -1063,10 +1057,7 @@ tf_cuda_library( "graph/**/*.cc", "public/session.h", "public/session_options.h", - "public/tensor_c_api.h", "public/version.h", - "util/tf_status_helper.*", - "util/checkpoint_reader.*", ], exclude = [ "**/*test*", @@ -1497,36 +1488,6 @@ tf_cc_tests( ], ) -tf_cc_tests( - size = "small", - linkopts = select({ - "//tensorflow:darwin": ["-headerpad_max_install_names"], - "//conditions:default": [], - }), - linkstatic = tf_kernel_tests_linkstatic(), - tests = [ - "client/tensor_c_api_test.cc", - ], - deps = [ - ":core", - ":core_cpu", - ":core_cpu_internal", - ":direct_session_internal", - ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":proto_text", - ":protos_all_cc", - ":test", - ":test_main", - ":testlib", - "//tensorflow/core/kernels:array", - "//tensorflow/core/kernels:math", - "//third_party/eigen3", - ], -) - # GPU-related tests tf_cc_tests_gpu( size = "small", @@ -2000,6 +1961,7 @@ filegroup( "example/testdata/parse_example_graph_def.pbtxt", ], ) + # ----------------------------------------------------------------------------- # Google-internal targets go here (must be at the end). diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index d9fd1d406ad..ea8de56e324 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -182,8 +182,8 @@ Status GrpcServer::Init() { // Finish setting up worker environment. worker_env_.graph_mgr = new GraphMgr(&worker_env_); - worker_env_.rendezvous_mgr = new RpcRendezvousMgr(&worker_env_); worker_env_.compute_pool = ComputePool(sess_opts); + worker_env_.rendezvous_mgr = new RpcRendezvousMgr(&worker_env_); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc index 96f7db2694b..b01d603c6a6 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc @@ -37,8 +37,9 @@ namespace { class RpcRemoteRendezvous : public BaseRemoteRendezvous { public: - RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id) - : BaseRemoteRendezvous(env, step_id, false) {} + RpcRemoteRendezvous(const WorkerEnv* env, WorkerCacheInterface* cache, + int64 step_id) + : BaseRemoteRendezvous(env, step_id, false), cache_(cache) {} protected: void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed, @@ -48,6 +49,7 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { private: ~RpcRemoteRendezvous() override {} + WorkerCacheInterface* cache_; // Not owned. TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous); }; @@ -55,13 +57,12 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous { class RpcRecvTensorCall : public BaseRecvTensorCall { public: RpcRecvTensorCall() - : wi_(nullptr), wc_(nullptr), allocator_(nullptr), dst_device_(nullptr) {} + : wi_(nullptr), allocator_(nullptr), dst_device_(nullptr) {} - void Init(WorkerCacheInterface* wc, WorkerInterface* wi, int64 step_id, - StringPiece key, Allocator* allocator, Device* dst_device, + void Init(WorkerInterface* wi, int64 step_id, StringPiece key, + Allocator* allocator, Device* dst_device, const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) { wi_ = wi; - wc_ = wc; allocator_ = allocator; dst_device_ = dst_device; recv_args_ = recv_args; @@ -73,7 +74,6 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { void Reset() { delete wi_; wi_ = nullptr; - wc_ = nullptr; allocator_ = nullptr; dst_device_ = nullptr; // We don't clear opts_ and assume that Init will set up the state for @@ -123,6 +123,8 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { const Rendezvous::DoneCallback& done() const { return done_; } private: + friend class RpcRemoteRendezvous; + // Start the main RecvTensor call, checking for an async abort. void StartRTCall(std::function recv_done) { wi_->RecvTensorAsync(&opts_, &req_, &resp_, @@ -137,8 +139,9 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { }); } - WorkerInterface* wi_; // Owned. - WorkerCacheInterface* wc_; // Not owned. + string src_worker_; + string src_rel_device_; + WorkerInterface* wi_; Allocator* allocator_; Device* dst_device_; CallOptions opts_; @@ -153,7 +156,6 @@ class RpcRecvTensorCall : public BaseRecvTensorCall { TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall); }; -namespace { class RpcRecvTensorFreeList { public: RpcRecvTensorFreeList() {} @@ -195,32 +197,99 @@ class RpcRecvTensorFreeList { }; static RpcRecvTensorFreeList call_freelist_; -} + +// A private cache that wraps env->worker_cache and allows reuse of +// WorkerInterface objects. +class WorkerFreeListCache : public WorkerCacheInterface { + public: + explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {} + + ~WorkerFreeListCache() { + for (auto p : workers_) { + delete p.second.worker; + } + } + + void ListWorkers(std::vector* workers) override { + wrapped_->ListWorkers(workers); + } + + WorkerInterface* CreateWorker(const string& target) override { + mutex_lock l(mu_); + auto p = workers_.find(target); + if (p != workers_.end()) { + return p->second.worker; + } + WorkerState state; + state.worker = wrapped_->CreateWorker(target); + if (state.worker != nullptr) { + workers_.insert(make_pair(target, state)); + } + return state.worker; + } + + void ReleaseWorker(const string& target, WorkerInterface* worker) override { + // TODO(jeff,sanjay): Should decrement ref-count when we implement eviction. + } + + bool GetDeviceBusNonBlocking(const string& device, + BusAdjacency* ba) override { + return wrapped_->GetDeviceBusNonBlocking(device, ba); + } + + void GetDeviceBusAsync(const string& device, BusAdjacency* ba, + StatusCallback done) override { + wrapped_->GetDeviceBusAsync(device, ba, done); + } + + void SetLogging(bool active) override { wrapped_->SetLogging(active); } + + void ClearLogs() override { wrapped_->ClearLogs(); } + + bool RetrieveLogs(int64 step_id, StepStats* ss) override { + return wrapped_->RetrieveLogs(step_id, ss); + } + + private: + WorkerCacheInterface* wrapped_; + + // Information kept per created WorkerInterface. + struct WorkerState { + WorkerInterface* worker; + // TODO(jeff,sanjay): Add reference count if we support eviction. + }; + + // TODO(jeff,sanjay): Eviction when the map becomes too big. + mutex mu_; + std::unordered_map workers_ GUARDED_BY(mu_); +}; void RpcRemoteRendezvous::RecvFromRemoteAsync( const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args, DoneCallback done) { Status s; - // key.src_device identifies a remote device. - string src_worker; - string src_rel_device; - if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker, - &src_rel_device)) { - s = errors::Internal(parsed.src_device, - " is invalid remote source device."); - } // TODO(jeff): Consider checking for a valid worker_cache during the // constructor of RpcRemoteRendezvous, rather than here, to simplify // the twisty logic below. - WorkerCacheInterface* worker_cache = env_->worker_cache; - if (s.ok() && worker_cache == nullptr) { + if (env_->worker_cache == nullptr) { s = errors::Internal("No remote worker cache available."); + done(s, Args(), recv_args, Tensor{}, false); + return; } - WorkerInterface* rwi = - (worker_cache ? worker_cache->CreateWorker(src_worker) : nullptr); + + // Prepare a RecvTensor call that can handle being aborted. + RpcRecvTensorCall* call = call_freelist_.New(); + + // key.src_device identifies a remote device. + if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_, + &call->src_rel_device_)) { + s = errors::Internal(parsed.src_device, + " is invalid remote source device."); + } + WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_); if (s.ok() && rwi == nullptr) { - s = errors::Internal("No worker known as ", src_worker); + s = errors::Internal("No worker known as ", call->src_worker_); } Device* dst_device; @@ -228,21 +297,20 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device); } if (!s.ok()) { + call_freelist_.Release(call); done(s, Args(), recv_args, Tensor{}, false); return; } Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs); - // Prepare a RecvTensor call that can handle being aborted. - RpcRecvTensorCall* call = call_freelist_.New(); - - call->Init(worker_cache, rwi, step_id_, parsed.FullKey(), allocator, - dst_device, recv_args, std::move(done)); + call->Init(rwi, step_id_, parsed.FullKey(), allocator, dst_device, recv_args, + std::move(done)); // Record "call" in active_ so that it can be aborted cleanly. RegisterCall(call); // Start "call". + Ref(); call->Start([this, call]() { // Removes "call" from active_. Prevent StartAbort(). DeregisterCall(call); @@ -255,15 +323,22 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync( call->tensor_proto(), call->recv_args().alloc_attrs, &val); } call->done()(s, Args(), call->recv_args(), val, call->is_dead()); + cache_->ReleaseWorker(call->src_worker_, call->wi_); + call->wi_ = nullptr; call_freelist_.Release(call); + Unref(); }); } } // namespace +RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env) + : BaseRendezvousMgr(env), + cache_(new WorkerFreeListCache(env->worker_cache)) {} + BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id, const WorkerEnv* worker_env) { - return new RpcRemoteRendezvous(worker_env, step_id); + return new RpcRemoteRendezvous(worker_env, cache_.get(), step_id); } } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h index 7447c94c392..6a65d04ba47 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_ #include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h" +#include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/platform/macros.h" @@ -42,13 +43,16 @@ namespace tensorflow { // RendezvousMgr must have keys generated by Rendezvous::CreateKey. class RpcRendezvousMgr : public BaseRendezvousMgr { public: - explicit RpcRendezvousMgr(const WorkerEnv* env) : BaseRendezvousMgr(env) {} + explicit RpcRendezvousMgr(const WorkerEnv* env); protected: BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env) override; private: + // Private cache_ that allows us to reuse WorkerInterface objects. + std::unique_ptr cache_; + TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr); }; diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc index 7e18278f309..dce49d33d77 100644 --- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc @@ -46,10 +46,28 @@ Rendezvous::ParsedKey MakeKey(const string& s) { return key; } +namespace { +// Fake cache implementation for WorkerEnv. +class DummyWorkerCache : public WorkerCacheInterface { + void ListWorkers(std::vector* workers) override {} + WorkerInterface* CreateWorker(const string& target) override { + return nullptr; + } + bool GetDeviceBusNonBlocking(const string& device, + BusAdjacency* ba) override { + return false; + } + void GetDeviceBusAsync(const string& device, BusAdjacency* ba, + StatusCallback done) override {} +}; +} // namespace + TEST(RpcRendezvousMgrTest, LocalSendRecv) { + DummyWorkerCache cache; WorkerEnv env; env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; + env.worker_cache = &cache; RpcRendezvousMgr rmgr(&env); const int64 step_id = 123; const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( @@ -71,9 +89,11 @@ TEST(RpcRendezvousMgrTest, LocalSendRecv) { } TEST(RpcRendezvousMgrTest, LocalAbort) { + DummyWorkerCache cache; WorkerEnv env; env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; + env.worker_cache = &cache; RpcRendezvousMgr rmgr(&env); const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( "/job:mnist/replica:1/task:2/cpu:0", 7890, @@ -107,9 +127,11 @@ TEST(RpcRendezvousMgrTest, LocalAbort) { } TEST(RpcRendezvousMgrTest, CleanupAll) { + DummyWorkerCache cache; WorkerEnv env; env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; + env.worker_cache = &cache; RpcRendezvousMgr rmgr(&env); const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( "/job:mnist/replica:1/task:2/cpu:0", 7890, @@ -140,9 +162,11 @@ class DummyDeviceContext : public DeviceContext { TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) { DummyDeviceContext* dc = new DummyDeviceContext(123); + DummyWorkerCache cache; WorkerEnv env; env.env = Env::Default(); env.worker_name = "/job:mnist/replica:1/task:2"; + env.worker_cache = &cache; RpcRendezvousMgr rmgr(&env); const int64 step_id = 123; const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey( diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h index 3efe14998fb..c46c0561364 100644 --- a/tensorflow/core/distributed_runtime/worker_cache.h +++ b/tensorflow/core/distributed_runtime/worker_cache.h @@ -19,7 +19,7 @@ limitations under the License. #include #include -#include "tensorflow/core/distributed_runtime/worker_interface.h" // for CallOptions +#include "tensorflow/core/distributed_runtime/worker_interface.h" #include "tensorflow/core/framework/device_attributes.pb.h" // for BusAdjacency #include "tensorflow/core/lib/core/status.h" @@ -28,7 +28,6 @@ typedef std::function StatusCallback; class ChannelCache; class StepStats; -class WorkerInterface; class WorkerCacheInterface { public: @@ -46,6 +45,17 @@ class WorkerCacheInterface { // ownership, not a cache lookup. virtual WorkerInterface* CreateWorker(const string& target) = 0; + // Release a worker previously returned by this->CreateWorker(target). + // + // TODO(jeff,sanjay): Consider moving target into WorkerInterface. + // TODO(jeff,sanjay): Consider disallowing direct deletion of WorkerInterface. + // TODO(jeff,sanjay): Unify all worker-cache impls and factor out a + // per-rpc-subsystem WorkerInterface creator. + virtual void ReleaseWorker(const string& target, WorkerInterface* worker) { + // Subclasses may override to reuse worker objects. + delete worker; + } + // Set *ba with the BusAdjacency of the specified remote device // within its local environment. Returns true if the device bus // affinity was set, using only locally cached data. Returns false diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index fac3a047990..eea3112b3fa 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -414,6 +414,99 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { return Status::OK(); } +Status MaxPoolShape(shape_inference::InferenceContext* c) { + const Shape* input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + + string data_format; + Status s = c->GetAttr("data_format", &data_format); + + std::vector strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 4) { + return errors::InvalidArgument( + "AvgPool requires the stride attribute to contain 4 values, but " + "got: ", + strides.size()); + } + + std::vector kernel_sizes; + TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); + if (kernel_sizes.size() != 4) { + return errors::InvalidArgument( + "AvgPool requires the ksize attribute to contain 4 values, but got: ", + kernel_sizes.size()); + } + + int32 stride_rows, stride_cols, stride_depth; + int32 kernel_rows, kernel_cols, kernel_depth; + + if (s.ok() && data_format == "NCHW") { + // Convert input shape to default NHWC for inference + input_shape = + c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2), + c->Dim(input_shape, 3), c->Dim(input_shape, 1)}}); + stride_depth = strides[1]; + stride_rows = strides[2]; + stride_cols = strides[3]; + kernel_depth = kernel_sizes[1]; + kernel_rows = kernel_sizes[2]; + kernel_cols = kernel_sizes[3]; + } else { + stride_rows = strides[1]; + stride_cols = strides[2]; + stride_depth = strides[3]; + kernel_rows = kernel_sizes[1]; + kernel_cols = kernel_sizes[2]; + kernel_depth = kernel_sizes[3]; + } + + const Dimension* batch_size_dim = c->Dim(input_shape, 0); + const Dimension* in_rows_dim = c->Dim(input_shape, 1); + const Dimension* in_cols_dim = c->Dim(input_shape, 2); + const Dimension* in_depth_dim = c->Dim(input_shape, 3); + + // At the moment we need to know the values of several fields. + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_depth_dim, "in_depth")); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + // TODO(mrry,shlens): Raise an error if the stride would cause + // information in the input to be ignored. This will require a change + // in the kernel implementation. + auto in_rows = c->Value(in_rows_dim); + auto in_cols = c->Value(in_cols_dim); + auto in_depth = c->Value(in_depth_dim); + + int64 output_rows, output_cols, output_depth; + int64 padding_before, padding_after; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_rows, kernel_rows, stride_rows, padding, &output_rows, &padding_before, + &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_cols, kernel_cols, stride_cols, padding, &output_cols, &padding_before, + &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_depth, kernel_depth, stride_depth, padding, &output_depth, + &padding_before, &padding_after)); + + const Shape* output_shape = + c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); + + if (data_format == "NCHW") { + // Convert output shape back to expected NCHW data format. + output_shape = + c->MakeShape({c->Dim(output_shape, 0), c->Dim(output_shape, 3), + c->Dim(output_shape, 1), c->Dim(output_shape, 2)}); + } + + c->set_output(0, output_shape); + return Status::OK(); +} + Status UnknownShape(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->UnknownShape()); diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index e4c5afc10eb..f1bdd5ee8d1 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -163,6 +163,9 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); // Shape function for AvgPool-like operations. Status AvgPoolShape(shape_inference::InferenceContext* c); +// Shape function for MaxPool-like operations. +Status MaxPoolShape(shape_inference::InferenceContext* c); + // Shape function for use with ops whose output shapes are unknown. Status UnknownShape(shape_inference::InferenceContext* c); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index ffe09040235..eada469b17a 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -485,6 +485,33 @@ TEST(CommonShapeFnsTest, AvgPool2DShapeTest) { INFER_ERROR("must be rank 4", op, "[4,4]"); } +TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { + ShapeInferenceTestOp op("MaxPool"); + auto set_op = [&op](const std::vector& strides, + const std::vector& ksizes, const string& padding, + const string& data_format) { + TF_CHECK_OK(NodeDefBuilder("test", "MaxPool") + .Input("input", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("ksize", ksizes) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Finalize(&op.node_def)); + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check the very-specific maxpooling features here, + // namely depthwise kernel and striding. + + // all 1 strides, depth 2 filter + set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC"); + INFER_OK(op, "[1,2,2,2]", "[d0_0,2,2,1]"); + + // depth 3 stride, 1x1x1 filter, NCHW + set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW"); + INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]"); +} + TEST(CommonShapeFnsTest, UnknownShapeTest) { { // Single output diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index dce44d8d4d8..c66d9fb4e14 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -492,6 +492,10 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) { *out = UnknownDim(); return Status::OK(); } + const int rank = t->dims(); + if (rank != 0) { + return errors::InvalidArgument("Input must be scalar but has rank ", rank); + } int64 val; if (t->dtype() == DT_INT32) { diff --git a/tensorflow/core/kernels/bcast_ops.cc b/tensorflow/core/kernels/bcast_ops.cc index 706ecaaee0a..b7851f9ff67 100644 --- a/tensorflow/core/kernels/bcast_ops.cc +++ b/tensorflow/core/kernels/bcast_ops.cc @@ -58,6 +58,8 @@ class BCastGradArgsOp : public OpKernel { Output(ctx, 1, bcast.grad_y_reduce_idx()); } + bool IsExpensive() override { return false; } + private: void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) { const int64 len = v.size(); diff --git a/tensorflow/core/kernels/one_hot_op.cc b/tensorflow/core/kernels/one_hot_op.cc index 916d85df2bc..1dc1bf65b22 100644 --- a/tensorflow/core/kernels/one_hot_op.cc +++ b/tensorflow/core/kernels/one_hot_op.cc @@ -85,26 +85,28 @@ class OneHotOp : public OpKernel { Tensor* output; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output)); - // prefix_dim_size == # of elements before the axis - // depth_v == # of elements per axis - // suffix_dim_size == # of elements after the axis - int64 prefix_dim_size = 1; - for (int i = 0; i < axis; ++i) { - prefix_dim_size *= indices_shape.dim_size(i); + if (output_shape.num_elements() > 0) { + // prefix_dim_size == # of elements before the axis + // depth_v == # of elements per axis + // suffix_dim_size == # of elements after the axis + int64 prefix_dim_size = 1; + for (int i = 0; i < axis; ++i) { + prefix_dim_size *= indices_shape.dim_size(i); + } + TI suffix_dim_size = indices_shape.num_elements() / prefix_dim_size; + + // Split indices into matrix of size prefix_dim_size x suffix_dim_size + auto indices_t = + indices.shaped({prefix_dim_size, suffix_dim_size}); + // Split output into 3-Tensor of size: + // prefix_dim_size x depth x suffix_dim_size. + auto output_t = + output->shaped({prefix_dim_size, depth_v, suffix_dim_size}); + + functor::OneHot::Compute(ctx->eigen_device(), + indices_t, on_value_t, + off_value_t, &output_t); } - TI suffix_dim_size = - indices_shape.num_elements() / prefix_dim_size; - - // Split indices into matrix of size prefix_dim_size x suffix_dim_size - auto indices_t = - indices.shaped({prefix_dim_size, suffix_dim_size}); - // Split output into 3-Tensor of size: - // prefix_dim_size x depth x suffix_dim_size. - auto output_t = - output->shaped({prefix_dim_size, depth_v, suffix_dim_size}); - - functor::OneHot::Compute(ctx->eigen_device(), indices_t, - on_value_t, off_value_t, &output_t); } private: @@ -113,12 +115,12 @@ class OneHotOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); }; -#define REGISTER_ONE_HOT_INDEX(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("OneHot") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("TI") \ - .TypeConstraint("T") \ - .HostMemory("depth"), \ +#define REGISTER_ONE_HOT_INDEX(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("OneHot") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("TI") \ + .TypeConstraint("T") \ + .HostMemory("depth"), \ OneHotOp); #define REGISTER_ONE_HOT(type) \ @@ -132,13 +134,13 @@ TF_CALL_ALL_TYPES(REGISTER_ONE_HOT); // Forward declarations of the functor specializations for GPU. namespace functor { -#define DECLARE_GPU_SPEC_INDEX(T, TI) \ - template <> \ - void OneHot::Compute( \ - const GPUDevice& d, const typename TTypes::ConstMatrix& indices, \ - const typename TTypes::ConstScalar& on_value, \ - const typename TTypes::ConstScalar& off_value, \ - typename TTypes::Tensor* output); \ +#define DECLARE_GPU_SPEC_INDEX(T, TI) \ + template <> \ + void OneHot::Compute( \ + const GPUDevice& d, const typename TTypes::ConstMatrix& indices, \ + const typename TTypes::ConstScalar& on_value, \ + const typename TTypes::ConstScalar& off_value, \ + typename TTypes::Tensor* output); \ extern template struct OneHot; #define DECLARE_GPU_SPEC(T) \ @@ -154,12 +156,12 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); } // namespace functor // Registration of the GPU implementations. -#define REGISTER_ONE_HOT_GPU_INDEX(type, index_type) \ - REGISTER_KERNEL_BUILDER(Name("OneHot") \ - .Device(DEVICE_GPU) \ - .TypeConstraint("TI") \ - .TypeConstraint("T") \ - .HostMemory("depth"), \ +#define REGISTER_ONE_HOT_GPU_INDEX(type, index_type) \ + REGISTER_KERNEL_BUILDER(Name("OneHot") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("TI") \ + .TypeConstraint("T") \ + .HostMemory("depth"), \ OneHotOp); #define REGISTER_ONE_HOT_GPU(type) \ diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index 3d95338d11d..168e8ec1eda 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -111,6 +111,17 @@ class MaxPoolingOp : public OpKernel { 0, params.forward_output_shape(), &output)); if (params.depth_window > 1) { + // Validate spec against the current implementation. A + // relaxation of these requirements would be ideal. + OP_REQUIRES(context, params.depth % params.depth_window == 0, + errors::Unimplemented( + "Depthwise max pooling requires " + "the depth window to evenly divide the input depth.")); + OP_REQUIRES( + context, params.depth_window == params.depth_stride, + errors::Unimplemented("Depthwise max pooling requires " + "the depth window to equal the depth stride.")); + DepthwiseMaxPool(context, output, tensor_in, params); } else { SpatialMaxPool(context, output, tensor_in, params, padding_); diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index 0acde9c498b..3cbd9691d18 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -118,16 +118,23 @@ class LinSpaceOp : public OpKernel { } }; -#define REGISTER_CPU_KERNEL(T) \ +#define REGISTER_KERNEL(DEV, T) \ REGISTER_KERNEL_BUILDER(Name("LinSpace") \ - .Device(DEVICE_CPU) \ + .Device(DEV) \ .TypeConstraint("T") \ .HostMemory("start") \ .HostMemory("stop") \ .HostMemory("num") \ .HostMemory("output"), \ LinSpaceOp); +#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, T) TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); +// NOTE(touts): We register the op on GPU but it still runs on CPU +// because its inputs and outputs are tagged as HostMemory. +#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, T) +TF_CALL_float(REGISTER_GPU_KERNEL); +TF_CALL_double(REGISTER_GPU_KERNEL); + } // namespace tensorflow diff --git a/tensorflow/core/lib/monitoring/counter.cc b/tensorflow/core/lib/monitoring/counter.cc deleted file mode 100644 index 37960a4acd7..00000000000 --- a/tensorflow/core/lib/monitoring/counter.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/lib/monitoring/counter.h" - -#include "tensorflow/core/platform/logging.h" - -namespace tensorflow { -namespace monitoring { - -void CounterCell::IncrementBy(const int64 step) { - DCHECK_LE(0, step) << "Must not decrement cumulative metrics."; - value_ += step; -} - -int64 CounterCell::value() const { return value_; } - -} // namespace monitoring -} // namespace tensorflow diff --git a/tensorflow/core/lib/monitoring/counter.h b/tensorflow/core/lib/monitoring/counter.h index af76884012d..7de85b75cb6 100644 --- a/tensorflow/core/lib/monitoring/counter.h +++ b/tensorflow/core/lib/monitoring/counter.h @@ -20,6 +20,8 @@ limitations under the License. #include #include +#include "tensorflow/core/lib/monitoring/metric_def.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" @@ -71,9 +73,12 @@ class CounterCell { template class Counter { public: - Counter() {} ~Counter() {} + explicit Counter( + const MetricDef& metric_def) + : metric_def_(metric_def) {} + // Retrieves the cell for the specified labels, creating it on demand if // not already present. template @@ -82,6 +87,10 @@ class Counter { private: mutable mutex mu_; + // The metric definition. This will be used to identify the metric when we + // register it for exporting. + const MetricDef metric_def_; + using LabelArray = std::array; std::map cells_ GUARDED_BY(mu_); @@ -92,6 +101,13 @@ class Counter { // Implementation details follow. API readers may skip. //// +inline void CounterCell::IncrementBy(const int64 step) { + DCHECK_LE(0, step) << "Must not decrement cumulative metrics."; + value_ += step; +} + +inline int64 CounterCell::value() const { return value_; } + template template CounterCell* Counter::GetCell(const Labels&... labels) diff --git a/tensorflow/core/lib/monitoring/counter_test.cc b/tensorflow/core/lib/monitoring/counter_test.cc index 0010662e263..0e42aed794d 100644 --- a/tensorflow/core/lib/monitoring/counter_test.cc +++ b/tensorflow/core/lib/monitoring/counter_test.cc @@ -19,26 +19,28 @@ limitations under the License. namespace tensorflow { namespace monitoring { +namespace { class LabeledCounterTest : public ::testing::Test { protected: LabeledCounterTest() {} - Counter<1> counter_; + Counter<1> counter_with_labels_{{"/tensorflow/test/counter_with_labels_", + "Counter with one label.", "One label"}}; }; TEST_F(LabeledCounterTest, InitializedWithZero) { - EXPECT_EQ(0, counter_.GetCell("Empty")->value()); + EXPECT_EQ(0, counter_with_labels_.GetCell("Empty")->value()); } TEST_F(LabeledCounterTest, GetCell) { - auto* cell = counter_.GetCell("GetCellOp"); + auto* cell = counter_with_labels_.GetCell("GetCellOp"); EXPECT_EQ(0, cell->value()); cell->IncrementBy(42); EXPECT_EQ(42, cell->value()); - auto* same_cell = counter_.GetCell("GetCellOp"); + auto* same_cell = counter_with_labels_.GetCell("GetCellOp"); EXPECT_EQ(42, same_cell->value()); same_cell->IncrementBy(58); @@ -49,29 +51,31 @@ TEST_F(LabeledCounterTest, GetCell) { using LabeledCounterDeathTest = LabeledCounterTest; TEST_F(LabeledCounterDeathTest, DiesOnDecrement) { - EXPECT_DEBUG_DEATH({ counter_.GetCell("DyingOp")->IncrementBy(-1); }, - "decrement"); + EXPECT_DEBUG_DEATH( + { counter_with_labels_.GetCell("DyingOp")->IncrementBy(-1); }, + "decrement"); } class UnlabeledCounterTest : public ::testing::Test { protected: UnlabeledCounterTest() {} - Counter<0> counter_; + Counter<0> counter_without_labels_{ + {"/tensorflow/test/counter0", "Counter without any labels."}}; }; TEST_F(UnlabeledCounterTest, InitializedWithZero) { - EXPECT_EQ(0, counter_.GetCell()->value()); + EXPECT_EQ(0, counter_without_labels_.GetCell()->value()); } TEST_F(UnlabeledCounterTest, GetCell) { - auto* cell = counter_.GetCell(); + auto* cell = counter_without_labels_.GetCell(); EXPECT_EQ(0, cell->value()); cell->IncrementBy(42); EXPECT_EQ(42, cell->value()); - auto* same_cell = counter_.GetCell(); + auto* same_cell = counter_without_labels_.GetCell(); EXPECT_EQ(42, same_cell->value()); same_cell->IncrementBy(58); @@ -82,8 +86,10 @@ TEST_F(UnlabeledCounterTest, GetCell) { using UnlabeledCounterDeathTest = UnlabeledCounterTest; TEST_F(UnlabeledCounterDeathTest, DiesOnDecrement) { - EXPECT_DEBUG_DEATH({ counter_.GetCell()->IncrementBy(-1); }, "decrement"); + EXPECT_DEBUG_DEATH({ counter_without_labels_.GetCell()->IncrementBy(-1); }, + "decrement"); } +} // namespace } // namespace monitoring } // namespace tensorflow diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h new file mode 100644 index 00000000000..f7037359eb3 --- /dev/null +++ b/tensorflow/core/lib/monitoring/metric_def.h @@ -0,0 +1,128 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ + +#include +#include + +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace monitoring { + +// Everything in the internal namespace is implementation details. Do not depend +// on this. +namespace internal { + +// Ensures that the string is a compile-time string literal. +class StringLiteral { + public: + // We allow implicit conversions here on purpose. + template + StringLiteral(const char (&data)[N]) : literal_(data, N) {} + + // This ctor will be called for non-literals, causing compile-time failure. + template + StringLiteral(const NotStringLiteral& not_string_literal) = delete; + + // Implicit conversion to StringPiece. + operator StringPiece() const { return literal_; } + + private: + const StringPiece literal_; +}; + +} // namespace internal + +// The different metric kinds available. +// +// Gauge indicates that the metric's values are instantaneous measurements of a +// (typically) continuously varying quantity. Examples: a process's current heap +// size, a queue's current length. +// +// Cumulative indicates that the metric's values represent non-negative changes +// over specified time periods. Example: the number of rpc calls to a service. +enum MetricKind { GAUGE, CUMULATIVE }; + +// Abstract base class for a metric definition. +// +// Unlike MetricDef, this class is non-templatized and allows storing and +// accessing metric definitions without the full type information. +// +// Everything except the value type of a metric is stored here. Please read +// MetricDef class comments for more details. +class AbstractMetricDef { + public: + MetricKind kind() const { return kind_; } + + StringPiece name() const { return name_; } + + StringPiece description() const { return description_; } + + const std::vector label_descriptions() const { + return label_descriptions_; + } + + private: + template + friend class MetricDef; + + AbstractMetricDef( + const MetricKind kind, const internal::StringLiteral name, + const internal::StringLiteral description, + const std::vector& label_descriptions) + : kind_(kind), + name_(name), + description_(description), + label_descriptions_( + {label_descriptions.begin(), label_descriptions.end()}) {} + + const MetricKind kind_; + const StringPiece name_; + const StringPiece description_; + const std::vector label_descriptions_; +}; + +// Metric definition. +// +// A metric is defined by its kind, value-type, name, description and the +// description of its labels. +// +// NOTE: We allow only string literals for the name, description and label +// descriptions because these should be fixed at compile-time and shouldn't be +// dynamic. +template +class MetricDef : public AbstractMetricDef { + public: + using value_type = Value; + + template + MetricDef(const internal::StringLiteral name, + const internal::StringLiteral description, + const LabelDesc&... label_descriptions) + : AbstractMetricDef(metric_kind, name, description, + {label_descriptions...}) { + static_assert(sizeof...(LabelDesc) == NumLabels, + "Mismatch between Counter and number of label " + "descriptions."); + } +}; + +} // namespace monitoring +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_ diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 42013bed56a..fe3d7406961 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -344,6 +344,34 @@ REGISTER_OP("Split") .Output("output: num_split * T") .Attr("num_split: int >= 1") .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + const Dimension* split_dimension; + TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(0, &split_dimension)); + int num_split = c->num_outputs(); + const Shape* input = c->input(1); + const Shape* out; + if (!c->ValueKnown(split_dimension)) { + if (c->RankKnown(input)) { + std::vector dims; + dims.resize(c->Rank(input)); + for (int i = 0; i < dims.size(); ++i) dims[i] = c->UnknownDim(); + out = c->MakeShape(dims); + } else { + out = c->UnknownShape(); + } + } else { + int64 split_dim = c->Value(split_dimension); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input)); + const Dimension* split_dim_size; + TF_RETURN_WITH_CONTEXT_IF_ERROR( + c->Divide(c->Dim(input, split_dim), num_split, &split_dim_size), + "Number of ways to split should evenly divide the split dimension"); + TF_RETURN_IF_ERROR( + c->ReplaceDim(input, split_dim, split_dim_size, &out)); + } + for (int i = 0; i < num_split; ++i) c->set_output(i, out); + return Status::OK(); + }) .Doc(R"doc( Splits a tensor into `num_split` tensors along one dimension. @@ -1443,7 +1471,7 @@ Status ShapeShapeFn(InferenceContext* c) { for (int i = 0; i < c->num_inputs(); ++i) { const Dimension* dim; if (c->RankKnown(c->input(i))) { - dim = c->MakeDim(c->Rank(c->input(0))); + dim = c->MakeDim(c->Rank(c->input(i))); } else { dim = c->UnknownDim(); } @@ -1479,6 +1507,7 @@ REGISTER_OP("ShapeN") .Output("output: N * int32") .Attr("N: int") .Attr("T: type") + .SetShapeFn(ShapeShapeFn) .Doc(R"doc( Returns shape of tensors. @@ -1493,6 +1522,42 @@ REGISTER_OP("ReverseSequence") .Attr("seq_dim: int") .Attr("batch_dim: int = 0") .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + const Shape* input = c->input(0); + const Shape* seq_lens_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape)); + + int64 seq_dim; + TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim)); + int64 batch_dim; + TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim)); + + if (!c->RankKnown(input)) { + return shape_inference::UnknownShape(c); + } + + // Validate batch_dim and seq_dim against input. + const int32 input_rank = c->Rank(input); + if (batch_dim >= input_rank) { + return errors::InvalidArgument("batch_dim must be < input rank: ", + batch_dim, " vs. ", input_rank); + } + if (seq_dim >= input_rank) { + return errors::InvalidArgument("seq_dim must be < input rank: ", + seq_dim, " vs. ", input_rank); + } + + const Dimension* batch_dim_dim = c->Dim(input, batch_dim); + TF_RETURN_IF_ERROR( + c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim)); + + // Replace batch_dim of input with batch_size + const Shape* output_shape; + TF_RETURN_IF_ERROR( + c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape)); + c->set_output(0, output_shape); + return Status::OK(); + }) .Doc(R"doc( Reverses variable length slices. @@ -1564,6 +1629,7 @@ REGISTER_OP("Rank") .Input("input: T") .Output("output: int32") .Attr("T: type") + .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc( Returns the rank of a tensor. @@ -1587,6 +1653,7 @@ REGISTER_OP("Size") .Input("input: T") .Output("output: int32") .Attr("T: type") + .SetShapeFn(shape_inference::ScalarShape) .Doc(R"doc( Returns the size of a tensor. @@ -1669,7 +1736,7 @@ begin_mask: a bitmask where a bit i being 1 means to ignore the begin begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or `[-1, n-1]` if `stride[i] < 0` end_mask: analogous to `begin_mask` -ellipsis_mask: a bitmask where bit `i` being 1 means the `i`th +ellipsis_mask: a bitmask where bit `i` being 1 means the `i`th position is actually an ellipsis. One bit at most can be 1. new_axis_mask: a bitmask where bit `i` being 1 means the `i`th position creates a dimension in the tensor of length 1. Thus @@ -1678,7 +1745,7 @@ new_axis_mask: a bitmask where bit `i` being 1 means the `i`th shrink_axis_mask: a bitmask where bit `i` implies that the `i`th position should shrink the dimensionality. begin and end must imply a slice of size 1 in the dimension. For example in - python one might do `foo[:,3,:]` which would result in + python one might do `foo[:,3,:]` which would result in `shrink_axis_mask` being 2. )doc"); @@ -1705,7 +1772,7 @@ as `shape`). The gradient will be zero in any element that the slice does not select. Arguments are the same as StridedSliceGrad with the exception that -`dy` is the input gradient to be propagated and `shape` is the +`dy` is the input gradient to be propagated and `shape` is the shape of `StridedSlice`'s `input`. )doc"); @@ -1744,7 +1811,14 @@ each repeated tile of `input` into `output`. )doc"); // -------------------------------------------------------------------------- -REGISTER_OP("Where").Input("input: bool").Output("index: int64").Doc(R"doc( +REGISTER_OP("Where") + .Input("input: bool") + .Output("index: int64") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0)))); + return Status::OK(); + }) + .Doc(R"doc( Returns locations of true values in a boolean tensor. This operation returns the coordinates of true elements in `input`. The diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index eaaaaf884c5..6516b24f0b5 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -237,6 +237,20 @@ TEST(ArrayOpsTest, Shape_ShapeFn) { INFER_OK(op, "[?,2,3,4,5]", "[5]"); } +TEST(ArrayOpsTest, ShapeN_ShapeFn) { + ShapeInferenceTestOp op("ShapeN"); + int n = 3; + std::vector src_list; + for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT); + TF_CHECK_OK(NodeDefBuilder("test", "ShapeN") + .Input(src_list) + .Attr("N", n) + .Finalize(&op.node_def)); + INFER_OK(op, "?;?;?", "[?];[?];[?]"); + INFER_OK(op, "[?];[?];[?]", "[1];[1];[1]"); + INFER_OK(op, "[?,2,3,4,5];?;[1,?,3]", "[5];[?];[3]"); +} + TEST(ArrayOpsTest, Unique_ShapeFn) { ShapeInferenceTestOp op("Unique"); INFER_OK(op, "?", "[?];in0"); @@ -696,4 +710,61 @@ TEST(ArrayOpsTest, Squeeze_ShapeFn) { INFER_ERROR("not in [-3,3)", op, "[1,2,3]"); } +TEST(ArrayOpsTest, ReverseSequence_ShapeFn) { + ShapeInferenceTestOp op("ReverseSequence"); + auto rebuild_node_def = [&op](const int32 seq_dim, const int32 batch_dim) { + TF_CHECK_OK(NodeDefBuilder("test", "ReverseSequence") + .Input("input", 0, DT_FLOAT) + .Input("seq_lengths", 1, DT_INT64) + .Attr("seq_dim", seq_dim) + .Attr("batch_dim", batch_dim) + .Finalize(&op.node_def)); + }; + + rebuild_node_def(1, 2); + // No valid shape provided, so output is unknown. + INFER_OK(op, "?;[10]", "?"); + + // Bad rank for seq_lengths + INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[10,10]"); + + // Validate seq_dim and batch_dim + rebuild_node_def(1, 4); + INFER_ERROR("batch_dim must be < input rank", op, "[1,2,3];[3]"); + rebuild_node_def(4, 1); + INFER_ERROR("seq_dim must be < input rank", op, "[1,2,3];[3]"); + + rebuild_node_def(1, 2); + INFER_OK(op, "[1,2,3];[3]", "[d0_0,d0_1,d0_2]"); + // Resolves uncertainty on batch dimension by merging. + INFER_OK(op, "[1,2,?];[3]", "[d0_0,d0_1,d1_0]"); + INFER_OK(op, "[1,2,3];[?]", "[d0_0,d0_1,d0_2]"); +} + +TEST(ArrayOpsTest, Split_ShapeFn) { + ShapeInferenceTestOp op("Split"); + op.input_tensors.resize(2); + + // No value for split_dim and no input. + TF_CHECK_OK(NodeDefBuilder("test", "Split") + .Input("split_dim", 0, DT_INT32) + .Input("value", 1, DT_FLOAT) + .Attr("num_split", 2) + .Finalize(&op.node_def)); + INFER_OK(op, "?;?", "?;?"); + // If the rank is known, we know the rank of each output. + INFER_OK(op, "?;[?,?]", "[?,?];[?,?]"); + + // split_dim is known. + Tensor split_dim = test::AsTensor({1, 2}); + op.input_tensors[0] = &split_dim; + INFER_ERROR("Input must be scalar but has rank 1", op, "[?];[?,?]"); + split_dim = test::AsScalar(1); + INFER_OK(op, "?;?", "?;?"); + INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]"); + INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]"); + INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]"); + INFER_ERROR("Dimension size must be divisible by 2 but is 5", op, "?;[1,5]"); +} + } // end namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 360cf68fb43..03ada875112 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -858,6 +858,7 @@ REGISTER_OP("MaxPool") .Attr(GetConvnetDataFormatAttrString()) .Input("input: T") .Output("output: T") + .SetShapeFn(shape_inference::MaxPoolShape) .Doc(R"doc( Performs max pooling on the input. @@ -914,6 +915,11 @@ REGISTER_OP("MaxPoolWithArgmax") .Output("output: T") .Output("argmax: Targmax") .Attr("T: {float, half} = DT_FLOAT") + .SetShapeFn([](InferenceContext* c) { + TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); + c->set_output(1, c->output(0)); + return Status::OK(); + }) .Doc(R"doc( Performs max pooling on the input and outputs both max values and indices. diff --git a/tensorflow/core/public/README.md b/tensorflow/core/public/README.md index cd1cefbb158..45767e1c8c8 100644 --- a/tensorflow/core/public/README.md +++ b/tensorflow/core/public/README.md @@ -21,7 +21,7 @@ Then: ```python import tensorflow as tf -with tf.Session("local"): +with tf.Session(): input1 = tf.constant(1.0, shape=[1, 1], name="input1") input2 = tf.constant(2.0, shape=[1, 1], name="input2") output = tf.matmul(input1, input2) diff --git a/tensorflow/g3doc/api_docs/python/client.md b/tensorflow/g3doc/api_docs/python/client.md index a66e7cef00f..1490bbaec70 100644 --- a/tensorflow/g3doc/api_docs/python/client.md +++ b/tensorflow/g3doc/api_docs/python/client.md @@ -373,8 +373,7 @@ the session constructor. * `target`: (Optional.) The execution engine to connect to. - Defaults to using an in-process engine. At present, no value - other than the empty string is supported. + Defaults to using an in-process engine. * `graph`: (Optional.) The `Graph` to be launched (described above). * `config`: (Optional) `ConfigProto` proto used to configure the session. diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.InteractiveSession.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.InteractiveSession.md index cdb5101815d..308a0a80b49 100644 --- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.InteractiveSession.md +++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard2/tf.InteractiveSession.md @@ -53,8 +53,7 @@ the session constructor. * `target`: (Optional.) The execution engine to connect to. - Defaults to using an in-process engine. At present, no value - other than the empty string is supported. + Defaults to using an in-process engine. * `graph`: (Optional.) The `Graph` to be launched (described above). * `config`: (Optional) `ConfigProto` proto used to configure the session. diff --git a/tensorflow/g3doc/resources/faq.md b/tensorflow/g3doc/resources/faq.md index 61437286545..6dc208f8dd6 100644 --- a/tensorflow/g3doc/resources/faq.md +++ b/tensorflow/g3doc/resources/faq.md @@ -147,7 +147,7 @@ graphs and running steps; we also have an experimental API for We would like to support more client languages, as determined by community interest. TensorFlow has a -[C-based client API](https://www.tensorflow.org/code/tensorflow/core/public/tensor_c_api.h) +[C-based client API](https://www.tensorflow.org/code/tensorflow/c/c_api.h) that makes it easy to build a client in many different languages. We invite contributions of new language bindings. diff --git a/tensorflow/g3doc/tutorials/index.md b/tensorflow/g3doc/tutorials/index.md index 292596837da..a489d977c8f 100644 --- a/tensorflow/g3doc/tutorials/index.md +++ b/tensorflow/g3doc/tutorials/index.md @@ -2,6 +2,10 @@ ## Basic Neural Networks +The first few Tensorflow tutorials guide you through training and testing a +simple neural network to classify handwritten digits from the MNIST database of +digit images. + ### MNIST For ML Beginners If you're new to machine learning, we recommend starting here. You'll learn @@ -27,13 +31,6 @@ example. [View Tutorial](../tutorials/mnist/tf/index.md) -### MNIST Data Download - -Details about downloading the MNIST handwritten digits data set. Exciting -stuff. - -[View Tutorial](../tutorials/mnist/download/index.md) - ## Easy ML with tf.contrib.learn diff --git a/tensorflow/g3doc/tutorials/leftnav_files b/tensorflow/g3doc/tutorials/leftnav_files index 09cd084b490..9c80a6c6e1b 100644 --- a/tensorflow/g3doc/tutorials/leftnav_files +++ b/tensorflow/g3doc/tutorials/leftnav_files @@ -2,7 +2,6 @@ mnist/beginners/index.md mnist/pros/index.md mnist/tf/index.md -mnist/download/index.md ### Easy ML with tf.contrib.learn tflearn/index.md linear/overview.md diff --git a/tensorflow/g3doc/tutorials/mnist/beginners/index.md b/tensorflow/g3doc/tutorials/mnist/beginners/index.md index 5d099c4bf2f..e3302db200c 100644 --- a/tensorflow/g3doc/tutorials/mnist/beginners/index.md +++ b/tensorflow/g3doc/tutorials/mnist/beginners/index.md @@ -1,11 +1,11 @@ # MNIST For ML Beginners *This tutorial is intended for readers who are new to both machine learning and -TensorFlow. If you already -know what MNIST is, and what softmax (multinomial logistic) regression is, -you might prefer this [faster paced tutorial](../pros/index.md). -Be sure to [install TensorFlow](../../../get_started/os_setup.md) before -starting either tutorial.* +TensorFlow. If you already know what MNIST is, and what softmax (multinomial +logistic) regression is, you might prefer this +[faster paced tutorial](../pros/index.md). Be sure to +[install TensorFlow](../../../get_started/os_setup.md) before starting either +tutorial.* When one learns how to program, there's a tradition that the first thing you do is print "Hello World." Just like programming has Hello World, machine learning @@ -33,21 +33,45 @@ important to understand the ideas behind it: both how TensorFlow works and the core machine learning concepts. Because of this, we are going to very carefully work through the code. +## About this tutorial + +This tutorial is an explanation, line by line, of what is happening in the +[mnist_softmax.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_softmax.py) code. + +You can use this tutorial in a few different ways, including: + +- Copy and paste each code snippet, line by line, into a Python environment as + you read through the explanations of each line. + +- Run the entire `mnist_softmax.py` Python file either before or after reading + through the explanations, and use this tutorial to understand the lines of + code that aren't clear to you. + +What we will accomplish in this tutorial: + +- Learn about the MNIST data and softmax regressions + +- Create a function that is a model for recognizing digits, based on looking at + every pixel in the image + +- Use Tensorflow to train the model to recognize digits by having it "look" at + thousands of examples (and run our first Tensorflow session to do so) + +- Check the model's accuracy with our test data + ## The MNIST Data The MNIST data is hosted on -[Yann LeCun's website](http://yann.lecun.com/exdb/mnist/). For your -convenience, we've included some python code to download and install the data -automatically. You can either download -[the code](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/input_data.py) -and import it as below, or simply copy and paste it in. +[Yann LeCun's website](http://yann.lecun.com/exdb/mnist/). If you are copying and +pasting in the code from this tutorial, start here with these two lines of code +which will download and read in the data automatically: ```python from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) ``` -The downloaded data is split into three parts, 55,000 data points of training +The MNIST data is split into three parts: 55,000 data points of training data (`mnist.train`), 10,000 points of test data (`mnist.test`), and 5,000 points of validation data (`mnist.validation`). This split is very important: it's essential in machine learning that we have separate data which we don't @@ -55,10 +79,10 @@ learn from so that we can make sure that what we've learned actually generalizes! As mentioned earlier, every MNIST data point has two parts: an image of a -handwritten digit and a corresponding label. We will call the images "xs" and -the labels "ys". Both the training set and test set contain xs and ys, for -example the training images are `mnist.train.images` and the train labels are -`mnist.train.labels`. +handwritten digit and a corresponding label. We'll call the images "x" +and the labels "y". Both the training set and test set contain images and their +corresponding labels; for example the training images are `mnist.train.images` +and the training labels are `mnist.train.labels`. Each image is 28 pixels by 28 pixels. We can interpret this as a big array of numbers: @@ -77,26 +101,26 @@ From this perspective, the MNIST images are just a bunch of points in a Flattening the data throws away information about the 2D structure of the image. Isn't that bad? Well, the best computer vision methods do exploit this structure, and we will in later tutorials. But the simple method we will be -using here, a softmax regression, won't. +using here, a softmax regression (defined below), won't. The result is that `mnist.train.images` is a tensor (an n-dimensional array) -with a shape of `[55000, 784]`. The first dimension indexes the images and the -second dimension indexes the pixels in each image. Each entry in the tensor is -the pixel intensity between 0 and 1, for a particular pixel in a particular -image. +with a shape of `[55000, 784]`. The first dimension is an index into the list +of images and the second dimension is the index for each pixel in each image. +East entry in the tensor is a pixel intensity between 0 and 1, for a particular +pixel in a particular image.
-The corresponding labels in MNIST are numbers between 0 and 9, describing -which digit a given image is of. -For the purposes of this tutorial, we're going to want our labels -as "one-hot vectors". A one-hot vector is a vector which is 0 in most -dimensions, and 1 in a single dimension. In this case, the \\(n\\)th digit will -be represented as a vector which is 1 in the \\(n\\)th dimensions. For example, -3 would be \\([0,0,0,1,0,0,0,0,0,0]\\). -Consequently, `mnist.train.labels` is a +Each image in MNIST has a corresponding label, a number between 0 and 9 +representing the digit drawn in the image. + +For the purposes of this tutorial, we're going to want our labels as "one-hot +vectors". A one-hot vector is a vector which is 0 in most dimensions, and 1 in a +single dimension. In this case, the \\(n\\)th digit will be represented as a +vector which is 1 in the \\(n\\)th dimensions. For example, 3 would be +\\([0,0,0,1,0,0,0,0,0,0]\\). Consequently, `mnist.train.labels` is a `[55000, 10]` array of floats.
@@ -107,24 +131,26 @@ We're now ready to actually make our model! ## Softmax Regressions -We know that every image in MNIST is a digit, whether it's a zero or a nine. We -want to be able to look at an image and give probabilities for it being each +We know that every image in MNIST is of a handwritten digit between zero and +nine. So there are only ten possible things that a given image can be. We want +to be able to look at an image and give the probabilities for it being each digit. For example, our model might look at a picture of a nine and be 80% sure it's a nine, but give a 5% chance to it being an eight (because of the top loop) -and a bit of probability to all the others because it isn't sure. +and a bit of probability to all the others because isn't 100% sure. This is a classic case where a softmax regression is a natural, simple model. If you want to assign probabilities to an object being one of several different -things, softmax is the thing to do. Even later on, when we train more -sophisticated models, the final step will be a layer of softmax. +things, softmax is the thing to do, because softmax gives us a list of values +between 0 and 1 that add up to 1. Even later on, when we train more sophisticated +models, the final step will be a layer of softmax. A softmax regression has two steps: first we add up the evidence of our input being in certain classes, and then we convert that evidence into probabilities. To tally up the evidence that a given image is in a particular class, we do a weighted sum of the pixel intensities. The weight is negative if that pixel -having a high intensity is evidence against the image being in that class, -and positive if it is evidence in favor. +having a high intensity is evidence against the image being in that class, and +positive if it is evidence in favor. The following diagram shows the weights one model learned for each of these classes. Red represents negative weights, while blue represents positive @@ -160,18 +186,16 @@ If you expand that equation out, you get: $$\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$ -But it's often more helpful to think of softmax the first way: -exponentiating its inputs and then normalizing them. -The exponentiation means that one more unit of evidence increases the weight -given to any hypothesis multiplicatively. -And conversely, having one less unit of evidence means that a -hypothesis gets a fraction of its earlier weight. No hypothesis ever has zero -or negative weight. Softmax then normalizes these weights, so that they add up -to one, forming a valid probability distribution. (To get more intuition about -the softmax function, check out the -[section](http://neuralnetworksanddeeplearning.com/chap3.html#softmax) -on it in Michael Nielsen's book, complete with an interactive visualization.) - +But it's often more helpful to think of softmax the first way: exponentiating +its inputs and then normalizing them. The exponentiation means that one more +unit of evidence increases the weight given to any hypothesis multiplicatively. +And conversely, having one less unit of evidence means that a hypothesis gets a +fraction of its earlier weight. No hypothesis ever has zero or negative +weight. Softmax then normalizes these weights, so that they add up to one, +forming a valid probability distribution. (To get more intuition about the +softmax function, check out the +[section](http://neuralnetworksanddeeplearning.com/chap3.html#softmax) on it in +Michael Nielsen's book, complete with an interactive visualization.) You can picture our softmax regression as looking something like the following, although with a lot more \\(x\\)s. For each output, we compute a weighted sum of @@ -199,26 +223,26 @@ More compactly, we can just write: $$y = \text{softmax}(Wx + b)$$ +Now let's turn that into something that Tensorflow can use. ## Implementing the Regression To do efficient numerical computing in Python, we typically use libraries like -NumPy that do expensive operations such as matrix multiplication outside Python, -using highly efficient code implemented in another language. -Unfortunately, there can still be a lot of overhead from switching back to -Python every operation. This overhead is especially bad if you want to run -computations on GPUs or in a distributed manner, where there can be a high cost -to transferring data. +[NumPy](http://www.numpy.org/) that do expensive operations such as matrix +multiplication outside Python, using highly efficient code implemented in +another language. Unfortunately, there can still be a lot of overhead from +switching back to Python every operation. This overhead is especially bad if you +want to run computations on GPUs or in a distributed manner, where there can be +a high cost to transferring data. -TensorFlow also does its heavy lifting outside python, -but it takes things a step further to avoid this overhead. -Instead of running a single expensive operation independently -from Python, TensorFlow lets us describe a graph of interacting operations that -run entirely outside Python. (Approaches like this can be seen in a few -machine learning libraries.) +TensorFlow also does its heavy lifting outside Python, but it takes things a +step further to avoid this overhead. Instead of running a single expensive +operation independently from Python, TensorFlow lets us describe a graph of +interacting operations that run entirely outside Python. (Approaches like this +can be seen in a few machine learning libraries.) -To use TensorFlow, we need to import it. +To use TensorFlow, first we need to import it. ```python import tensorflow as tf @@ -239,11 +263,10 @@ this as a 2-D tensor of floating-point numbers, with a shape `[None, 784]`. We also need the weights and biases for our model. We could imagine treating these like additional inputs, but TensorFlow has an even better way to handle -it: `Variable`. -A `Variable` is a modifiable tensor that lives in TensorFlow's graph of -interacting -operations. It can be used and even modified by the computation. For machine -learning applications, one generally has the model parameters be `Variable`s. +it: `Variable`. A `Variable` is a modifiable tensor that lives in TensorFlow's +graph of interacting operations. It can be used and even modified by the +computation. For machine learning applications, one generally has the model +parameters be `Variable`s. ```python W = tf.Variable(tf.zeros([784, 10])) @@ -260,17 +283,16 @@ Notice that `W` has a shape of [784, 10] because we want to multiply the evidence for the difference classes. `b` has a shape of [10] so we can add it to the output. -We can now implement our model. It only takes one line! +We can now implement our model. It only takes one line to define it! ```python y = tf.nn.softmax(tf.matmul(x, W) + b) ``` First, we multiply `x` by `W` with the expression `tf.matmul(x, W)`. This is -flipped from when we multiplied them in our equation, where we had \\(Wx\\), as a -small trick -to deal with `x` being a 2D tensor with multiple inputs. We then add `b`, and -finally apply `tf.nn.softmax`. +flipped from when we multiplied them in our equation, where we had \\(Wx\\), as +a small trick to deal with `x` being a 2D tensor with multiple inputs. We then +add `b`, and finally apply `tf.nn.softmax`. That's it. It only took us one line to define our model, after a couple short lines of setup. That isn't because TensorFlow is designed to make a softmax @@ -282,79 +304,80 @@ your computer's CPU, GPUs, and even phones! ## Training -In order to train our model, we need to define what it means for the model to -be good. Well, actually, in machine learning we typically define what it means -for a model to be bad, called the cost or loss, and then try to minimize how bad -it is. But the two are equivalent. +In order to train our model, we need to define what it means for the model to be +good. Well, actually, in machine learning we typically define what it means for +a model to be bad. We call this the cost, or the loss, and it represents how far +off our model is from our desired outcome. We try to minimize that error, and +the smaller the error margin, the better our model is. -One very common, very nice cost function is "cross-entropy." Surprisingly, -cross-entropy arises from thinking about information compressing codes in -information theory but it winds up being an important idea in lots of areas, -from gambling to machine learning. It's defined: +One very common, very nice function to determine the loss of a model is called +"cross-entropy." Cross-entropy arises from thinking about information +compressing codes in information theory but it winds up being an important idea +in lots of areas, from gambling to machine learning. It's defined as: $$H_{y'}(y) = -\sum_i y'_i \log(y_i)$$ Where \\(y\\) is our predicted probability distribution, and \\(y'\\) is the true -distribution (the one-hot vector we'll input). In some rough sense, the +distribution (the one-hot vector with the digit labels). In some rough sense, the cross-entropy is measuring how inefficient our predictions are for describing the truth. Going into more detail about cross-entropy is beyond the scope of this tutorial, but it's well worth [understanding](http://colah.github.io/posts/2015-09-Visual-Information/). -To implement cross-entropy we need to first add a new placeholder to input -the correct answers: +To implement cross-entropy we need to first add a new placeholder to input the +correct answers: ```python y_ = tf.placeholder(tf.float32, [None, 10]) ``` -Then we can implement the cross-entropy, \\(-\sum y'\log(y)\\): +Then we can implement the cross-entropy function, \\(-\sum y'\log(y)\\): ```python cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) ``` First, `tf.log` computes the logarithm of each element of `y`. Next, we multiply -each element of `y_` with the corresponding element of `tf.log(y)`. Then -`tf.reduce_sum` adds the elements in the second dimension of y, due to the -`reduction_indices=[1]` parameter. Finally, `tf.reduce_mean` computes the mean +each element of `y_` with the corresponding element of `tf.log(y)`. Then +`tf.reduce_sum` adds the elements in the second dimension of y, due to the +`reduction_indices=[1]` parameter. Finally, `tf.reduce_mean` computes the mean over all the examples in the batch. Now that we know what we want our model to do, it's very easy to have TensorFlow -train it to do so. -Because TensorFlow knows the entire graph of your computations, it -can automatically use the [backpropagation -algorithm](http://colah.github.io/posts/2015-08-Backprop/) -to efficiently determine how your variables affect the cost you ask it to +train it to do so. Because TensorFlow knows the entire graph of your +computations, it can automatically use the +[backpropagation algorithm](http://colah.github.io/posts/2015-08-Backprop/) to +efficiently determine how your variables affect the loss you ask it to minimize. Then it can apply your choice of optimization algorithm to modify the -variables and reduce the cost. +variables and reduce the loss. ```python train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) ``` -In this case, we ask TensorFlow to minimize `cross_entropy` using the gradient -descent algorithm with a learning rate of 0.5. Gradient descent is a simple -procedure, where TensorFlow simply shifts each variable a little bit in the -direction that reduces the cost. But TensorFlow also provides +In this case, we ask TensorFlow to minimize `cross_entropy` using the +[gradient descent algorithm](https://en.wikipedia.org/wiki/Gradient_descent) +with a learning rate of 0.5. Gradient descent is a simple procedure, where +TensorFlow simply shifts each variable a little bit in the direction that +reduces the cost. But TensorFlow also provides [many other optimization algorithms] (../../../api_docs/python/train.md#optimizers): using one is as simple as tweaking one line. -What TensorFlow actually does here, behind the scenes, is it adds new operations -to your graph which -implement backpropagation and gradient descent. Then it gives you back a -single operation which, when run, will do a step of gradient descent training, -slightly tweaking your variables to reduce the cost. +What TensorFlow actually does here, behind the scenes, is to add new operations +to your graph which implement backpropagation and gradient descent. Then it +gives you back a single operation which, when run, does a step of gradient +descent training, slightly tweaking your variables to reduce the loss. -Now we have our model set up to train. One last thing before we launch it, -we have to add an operation to initialize the variables we created: +Now we have our model set up to train. One last thing before we launch it, we +have to create an operation to initialize the variables we created. Note that +this defines the operation but does not run it yet: ```python init = tf.initialize_all_variables() ``` -We can now launch the model in a `Session`, and run the operation that +We can now launch the model in a `Session`, and now we run the operation that initializes the variables: ```python @@ -374,10 +397,10 @@ Each step of the loop, we get a "batch" of one hundred random data points from our training set. We run `train_step` feeding in the batches data to replace the `placeholder`s. -Using small batches of random data is called stochastic training -- in -this case, stochastic gradient descent. Ideally, we'd like to use all our data -for every step of training because that would give us a better sense of what -we should be doing, but that's expensive. So, instead, we use a different subset +Using small batches of random data is called stochastic training -- in this +case, stochastic gradient descent. Ideally, we'd like to use all our data for +every step of training because that would give us a better sense of what we +should be doing, but that's expensive. So, instead, we use a different subset every time. Doing this is cheap and has much of the same benefit. @@ -414,12 +437,12 @@ print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels} This should be about 92%. Is that good? Well, not really. In fact, it's pretty bad. This is because we're -using a very simple model. With some small changes, we can get to -97%. The best models can get to over 99.7% accuracy! (For more information, have -a look at this +using a very simple model. With some small changes, we can get to 97%. The best +models can get to over 99.7% accuracy! (For more information, have a look at +this [list of results](http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html).) What matters is that we learned from this model. Still, if you're feeling a bit -down about these results, check out [the next tutorial](../../../tutorials/mnist/pros/index.md) where we -do a lot better, and learn how to build more sophisticated models using -TensorFlow! +down about these results, check out +[the next tutorial](../../../tutorials/mnist/pros/index.md) where we do a lot +better, and learn how to build more sophisticated models using TensorFlow! diff --git a/tensorflow/g3doc/tutorials/mnist/download/index.md b/tensorflow/g3doc/tutorials/mnist/download/index.md deleted file mode 100644 index 16ff9e84227..00000000000 --- a/tensorflow/g3doc/tutorials/mnist/download/index.md +++ /dev/null @@ -1,85 +0,0 @@ -# MNIST Data Download - -Code: [tensorflow/examples/tutorials/mnist/](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/) - -The goal of this tutorial is to show how to download the dataset files required -for handwritten digit classification using the (classic) MNIST data set. - -## Tutorial Files - -This tutorial references the following files: - -File | Purpose ---- | --- -[`input_data.py`](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/input_data.py) | The code to download the MNIST dataset for training and evaluation. - -## Prepare the Data - -MNIST is a classic problem in machine learning. The problem is to look at -greyscale 28x28 pixel images of handwritten digits and determine which digit -the image represents, for all the digits from zero to nine. - -![MNIST Digits](../../../images/mnist_digits.png "MNIST Digits") - -For more information, refer to [Yann LeCun's MNIST page](http://yann.lecun.com/exdb/mnist/) -or [Chris Olah's visualizations of MNIST](http://colah.github.io/posts/2014-10-Visualizing-MNIST/). - -### Download - -[Yann LeCun's MNIST page](http://yann.lecun.com/exdb/mnist/) -also hosts the training and test data for download. - -File | Purpose ---- | --- -[`train-images-idx3-ubyte.gz`](http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz) | training set images - 55000 training images, 5000 validation images -[`train-labels-idx1-ubyte.gz`](http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz) | training set labels matching the images -[`t10k-images-idx3-ubyte.gz`](http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz) | test set images - 10000 images -[`t10k-labels-idx1-ubyte.gz`](http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz) | test set labels matching the images - -In the `input_data.py` file, the `maybe_download()` function will ensure these -files are downloaded into a local data folder for training. - -The folder name is specified in a flag variable at the top of the -`fully_connected_feed.py` file and may be changed to fit your needs. - -### Unpack and Reshape - -The files themselves are not in any standard image format and are manually -unpacked (following the instructions available at the website) by the -`extract_images()` and `extract_labels()` functions in `input_data.py`. - -The image data is extracted into a 2d tensor of: `[image index, pixel index]` -where each entry is the intensity value of a specific pixel in a specific -image, rescaled from `[0, 255]` to `[0, 1]`. The "image index" corresponds -to an image in the dataset, counting up from zero to the size of the dataset. -And the "pixel index" corresponds to a specific pixel in that image, ranging -from zero to the number of pixels in the image. - -The 60000 examples in the `train-*` files are then split into 55000 examples -for training and 5000 examples for validation. For all of the 28x28 -pixel greyscale images in the datasets the image size is 784 and so the output -tensor for the training set images is of shape `[55000, 784]`. - -The label data is extracted into a 1d tensor of: `[image index]` -with the class identifier for each example as the value. For the training set -labels, this would then be of shape `[55000]`. - -### DataSet Object - -The underlying code will download, unpack, and reshape images and labels for -the following datasets: - -Dataset | Purpose ---- | --- -`data_sets.train` | 55000 images and labels, for primary training. -`data_sets.validation` | 5000 images and labels, for iterative validation of training accuracy. -`data_sets.test` | 10000 images and labels, for final testing of trained accuracy. - -The `read_data_sets()` function will return a dictionary with a `DataSet` -instance for each of these three sets of data. The `DataSet.next_batch()` -method can be used to fetch a tuple consisting of `batch_size` lists of images -and labels to be fed into the running TensorFlow session. - -```python -images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size) -``` diff --git a/tensorflow/g3doc/tutorials/mnist/pros/index.md b/tensorflow/g3doc/tutorials/mnist/pros/index.md index f0bb36220a6..d3a0af6e652 100644 --- a/tensorflow/g3doc/tutorials/mnist/pros/index.md +++ b/tensorflow/g3doc/tutorials/mnist/pros/index.md @@ -1,10 +1,9 @@ -# Deep MNIST for Experts +# Deep MNIST for Experts TensorFlow is a powerful library for doing large-scale numerical computation. One of the tasks at which it excels is implementing and training deep neural -networks. -In this tutorial we will learn the basic building blocks of a TensorFlow model -while constructing a deep convolutional MNIST classifier. +networks. In this tutorial we will learn the basic building blocks of a +TensorFlow model while constructing a deep convolutional MNIST classifier. *This introduction assumes familiarity with neural networks and the MNIST dataset. If you don't have @@ -12,6 +11,30 @@ a background with them, check out the [introduction for beginners](../beginners/index.md). Be sure to [install TensorFlow](../../../get_started/os_setup.md) before starting.* + +## About this tutorial + +The first part of this tutorial explains what is happening in the +[mnist_softmax.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_softmax.py) +code, which is a basic implementation of a Tensorflow model. The second part +shows some ways to improve the accuracy. + +You can copy and paste each code snippet from this tutorial into a Python +environment, or you can choose to just read through the code. + +What we will accomplish in this tutorial: + +- Create a softmax regression function that is a model for recognizing MNIST + digits, based on looking at every pixel in the image + +- Use Tensorflow to train the model to recognize digits by having it "look" at + thousands of examples (and run our first Tensorflow session to do so) + +- Check the model's accuracy with our test data + +- Build, train, and test a multilayer convolutional neural network to improve + the results + ## Setup Before we create our model, we will first load the MNIST dataset, and start a @@ -19,10 +42,8 @@ TensorFlow session. ### Load MNIST Data -For your convenience, we've included -[a script](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/input_data.py) -which will help you download and import the MNIST dataset. Run the following commands to create a -directory `'MNIST_data'` in the current folder, the data files will be stored inside that directory. +If you are copying and pasting in the code from this tutorial, start here with +these two lines of code which will download and read in the data automatically: ```python from tensorflow.examples.tutorials.mnist import input_data @@ -30,9 +51,8 @@ mnist = input_data.read_data_sets('MNIST_data', one_hot=True) ``` Here `mnist` is a lightweight class which stores the training, validation, and -testing sets as NumPy arrays. -It also provides a function for iterating through data minibatches, which we -will use below. +testing sets as NumPy arrays. It also provides a function for iterating through +data minibatches, which we will use below. ### Start TensorFlow InteractiveSession @@ -40,17 +60,15 @@ TensorFlow relies on a highly efficient C++ backend to do its computation. The connection to this backend is called a session. The common usage for TensorFlow programs is to first create a graph and then launch it in a session. -Here we instead use the convenient `InteractiveSession` class, which -makes TensorFlow more flexible about how you -structure your code. -It allows you to interleave operations which build a +Here we instead use the convenient `InteractiveSession` class, which makes +TensorFlow more flexible about how you structure your code. It allows you to +interleave operations which build a [computation graph](../../../get_started/basic_usage.md#the-computation-graph) -with ones that run the graph. -This is particularly convenient when working in interactive contexts like -IPython. -If you are not using an `InteractiveSession`, then you should build -the entire computation graph before starting a session and [launching the -graph](../../../get_started/basic_usage.md#launching-the-graph-in-a-session). +with ones that run the graph. This is particularly convenient when working in +interactive contexts like IPython. If you are not using an +`InteractiveSession`, then you should build the entire computation graph before +starting a session and +[launching the graph](../../../get_started/basic_usage.md#launching-the-graph-in-a-session). ```python import tensorflow as tf @@ -60,19 +78,18 @@ sess = tf.InteractiveSession() #### Computation Graph To do efficient numerical computing in Python, we typically use libraries like -NumPy that do expensive operations such as matrix multiplication outside Python, -using highly efficient code implemented in another language. -Unfortunately, there can still be a lot of overhead from switching back to -Python every operation. This overhead is especially bad if you want to run -computations on GPUs or in a distributed manner, where there can be a high cost -to transferring data. +[NumPy](http://www.numpy.org/) that do expensive operations such as matrix +multiplication outside Python, using highly efficient code implemented in +another language. Unfortunately, there can still be a lot of overhead from +switching back to Python every operation. This overhead is especially bad if you +want to run computations on GPUs or in a distributed manner, where there can be +a high cost to transferring data. -TensorFlow also does its heavy lifting outside Python, -but it takes things a step further to avoid this overhead. -Instead of running a single expensive operation independently -from Python, TensorFlow lets us describe a graph of interacting operations that -run entirely outside Python. -This approach is similar to that used in Theano or Torch. +TensorFlow also does its heavy lifting outside Python, but it takes things a +step further to avoid this overhead. Instead of running a single expensive +operation independently from Python, TensorFlow lets us describe a graph of +interacting operations that run entirely outside Python. This approach is +similar to that used in Theano or Torch. The role of the Python code is therefore to build this external computation graph, and to dictate which parts of the computation graph should be run. See @@ -102,59 +119,58 @@ Here `x` and `y_` aren't specific values. Rather, they are each a `placeholder` -- a value that we'll input when we ask TensorFlow to run a computation. The input images `x` will consist of a 2d tensor of floating point numbers. -Here we assign it a `shape` of `[None, 784]`, where `784` is the dimensionality of -a single flattened MNIST image, and `None` indicates that the first dimension, -corresponding to the batch size, can be of any size. -The target output classes `y_` will also consist of a 2d tensor, -where each row is a one-hot 10-dimensional vector indicating -which digit class the corresponding MNIST image belongs to. +Here we assign it a `shape` of `[None, 784]`, where `784` is the dimensionality +of a single flattened 28 by 28 pixel MNIST image, and `None` indicates that the +first dimension, corresponding to the batch size, can be of any size. The +target output classes `y_` will also consist of a 2d tensor, where each row is a +one-hot 10-dimensional vector indicating which digit class (zero through nine) +the corresponding MNIST image belongs to. The `shape` argument to `placeholder` is optional, but it allows TensorFlow to automatically catch bugs stemming from inconsistent tensor shapes. ### Variables -We now define the weights `W` and biases `b` for our model. We could imagine treating -these like additional inputs, but TensorFlow has an even better way to handle -them: `Variable`. -A `Variable` is a value that lives in TensorFlow's computation graph. -It can be used and even modified by the computation. In machine -learning applications, one generally has the model parameters be `Variable`s. +We now define the weights `W` and biases `b` for our model. We could imagine +treating these like additional inputs, but TensorFlow has an even better way to +handle them: `Variable`. A `Variable` is a value that lives in TensorFlow's +computation graph. It can be used and even modified by the computation. In +machine learning applications, one generally has the model parameters be +`Variable`s. ```python W = tf.Variable(tf.zeros([784,10])) b = tf.Variable(tf.zeros([10])) ``` -We pass the initial value for each parameter in the call to `tf.Variable`. -In this case, we initialize both `W` and `b` as tensors full of -zeros. `W` is a 784x10 matrix (because we have 784 input features -and 10 outputs) and `b` is a 10-dimensional vector (because we have 10 classes). +We pass the initial value for each parameter in the call to `tf.Variable`. In +this case, we initialize both `W` and `b` as tensors full of zeros. `W` is a +784x10 matrix (because we have 784 input features and 10 outputs) and `b` is a +10-dimensional vector (because we have 10 classes). Before `Variable`s can be used within a session, they must be initialized using -that session. -This step takes the initial values (in this case tensors full of zeros) that -have already been specified, and assigns them to each `Variable`. This can be -done for all `Variables` at once. +that session. This step takes the initial values (in this case tensors full of +zeros) that have already been specified, and assigns them to each +`Variable`. This can be done for all `Variables` at once: ```python sess.run(tf.initialize_all_variables()) ``` -### Predicted Class and Cost Function +### Predicted Class and Loss Function -We can now implement our regression model. It only takes one line! -We multiply the vectorized input images `x` by the weight matrix `W`, add -the bias `b`, and compute the softmax probabilities that are assigned to each -class. +We can now implement our regression model. It only takes one line! We multiply +the vectorized input images `x` by the weight matrix `W`, add the bias `b`, and +compute the softmax probabilities that are assigned to each class. ```python y = tf.nn.softmax(tf.matmul(x,W) + b) ``` -The cost function to be minimized during training can be specified just as -easily. Our cost function will be the cross-entropy between the target and the -model's prediction. +We can specify a loss function just as easily. Loss indicates how bad the +model's prediction was on a single example; we try to minimize that while +training across all the examples. Here, our loss function is the cross-entropy +between the target and the model's prediction: ```python cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) @@ -165,16 +181,14 @@ the average over these sums. ## Train the Model -Now that we have defined our model and training cost function, it is -straightforward to train using TensorFlow. -Because TensorFlow knows the entire computation graph, it -can use automatic differentiation to find the gradients of the cost with -respect to each of the variables. -TensorFlow has a variety of -[builtin optimization algorithms] -(../../../api_docs/python/train.md#optimizers). -For this example, we will use steepest gradient descent, with a step length of -0.5, to descend the cross entropy. +Now that we have defined our model and training loss function, it is +straightforward to train using TensorFlow. Because TensorFlow knows the entire +computation graph, it can use automatic differentiation to find the gradients of +the loss with respect to each of the variables. TensorFlow has a variety of +[built-in optimization algorithms] +(../../../api_docs/python/train.md#optimizers). For this example, we will use +steepest gradient descent, with a step length of 0.5, to descend the cross +entropy. ```python train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) @@ -184,9 +198,9 @@ What TensorFlow actually did in that single line was to add new operations to the computation graph. These operations included ones to compute gradients, compute parameter update steps, and apply update steps to the parameters. -The returned operation `train_step`, when run, will apply the gradient -descent updates to the parameters. Training the model can therefore be -accomplished by repeatedly running `train_step`. +The returned operation `train_step`, when run, will apply the gradient descent +updates to the parameters. Training the model can therefore be accomplished by +repeatedly running `train_step`. ```python for i in range(1000): @@ -194,22 +208,21 @@ for i in range(1000): train_step.run(feed_dict={x: batch[0], y_: batch[1]}) ``` -Each training iteration we load 100 training examples. We then run the +We load 100 training examples in each training iteration. We then run the `train_step` operation, using `feed_dict` to replace the `placeholder` tensors -`x` and `y_` with the training examples. -Note that you can replace any tensor in your computation graph using `feed_dict` --- it's not restricted to just `placeholder`s. +`x` and `y_` with the training examples. Note that you can replace any tensor +in your computation graph using `feed_dict` -- it's not restricted to just +`placeholder`s. ### Evaluate the Model How well did our model do? -First we'll figure out where we predicted the correct label. `tf.argmax` -is an extremely useful function which gives you the index of the highest entry -in a tensor along some axis. For example, `tf.argmax(y,1)` is the label our -model thinks is most likely for each input, while `tf.argmax(y_,1)` is the -true label. We can use `tf.equal` to check if our prediction matches the -truth. +First we'll figure out where we predicted the correct label. `tf.argmax` is an +extremely useful function which gives you the index of the highest entry in a +tensor along some axis. For example, `tf.argmax(y,1)` is the label our model +thinks is most likely for each input, while `tf.argmax(y_,1)` is the true +label. We can use `tf.equal` to check if our prediction matches the truth. ```python correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) @@ -241,10 +254,11 @@ to around 99.2% accuracy -- not state of the art, but respectable. To create this model, we're going to need to create a lot of weights and biases. One should generally initialize weights with a small amount of noise for -symmetry breaking, and to prevent 0 gradients. Since we're using ReLU neurons, -it is also good practice to initialize them with a slightly positive initial -bias to avoid "dead neurons". Instead of doing this repeatedly while we build -the model, let's create two handy functions to do it for us. +symmetry breaking, and to prevent 0 gradients. Since we're using +[ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) neurons, it is +also good practice to initialize them with a slightly positive initial bias to +avoid "dead neurons". Instead of doing this repeatedly while we build the model, +let's create two handy functions to do it for us. ```python def weight_variable(shape): @@ -362,13 +376,21 @@ y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) ### Train and Evaluate the Model -How well does this model do? -To train and evaluate it we will use code that is nearly identical to that for -the simple one layer SoftMax network above. -The differences are that: we will replace the steepest gradient descent -optimizer with the more sophisticated ADAM optimizer; we will include the -additional parameter `keep_prob` in `feed_dict` to control the dropout rate; -and we will add logging to every 100th iteration in the training process. +How well does this model do? To train and evaluate it we will use code that is +nearly identical to that for the simple one layer SoftMax network above. + +The differences are that: + +- We will replace the steepest gradient descent optimizer with the more + sophisticated ADAM optimizer. + +- We will include the additional parameter `keep_prob` in `feed_dict` to control + the dropout rate. + +- We will add logging to every 100th iteration in the training process. + +Feel free to go ahead and run this code, but it does 20,000 training iterations +and may take a while (possibly up to half an hour), depending on your processor. ```python cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices=[1])) diff --git a/tensorflow/g3doc/tutorials/mnist/tf/index.md b/tensorflow/g3doc/tutorials/mnist/tf/index.md index 9d83393dc0c..c7d5eec401b 100644 --- a/tensorflow/g3doc/tutorials/mnist/tf/index.md +++ b/tensorflow/g3doc/tutorials/mnist/tf/index.md @@ -58,9 +58,6 @@ Dataset | Purpose `data_sets.validation` | 5000 images and labels, for iterative validation of training accuracy. `data_sets.test` | 10000 images and labels, for final testing of trained accuracy. -For more information about the data, please read the [Download](../../../tutorials/mnist/download/index.md) -tutorial. - ### Inputs and Placeholders The `placeholder_inputs()` function creates two [`tf.placeholder`](../../../api_docs/python/io_ops.md#placeholder) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index fc2abced83e..c5c5573f211 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1065,6 +1065,8 @@ tf_cuda_library( ":construction_fails_op", ":numpy_lib", ":test_ops_kernels", + "//tensorflow/c:c_api", + "//tensorflow/c:tf_status_helper", "//tensorflow/core", "//tensorflow/core:all_kernels", "//tensorflow/core:direct_session", @@ -1104,6 +1106,9 @@ tf_py_wrap_cc( ":tf_session_helper", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_session", + "//tensorflow/c:c_api", + "//tensorflow/c:checkpoint_reader", + "//tensorflow/c:tf_status_helper", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime:server_lib", "//util/python:python_headers", @@ -1265,7 +1270,7 @@ cuda_py_tests( "training/server_lib_test.py", "training/session_manager_test.py", "training/supervisor_test.py", - "training/saver_test.py", + "training/saver_large_variable_test.py", ], ), additional_deps = [ @@ -1273,14 +1278,15 @@ cuda_py_tests( ], ) -cuda_py_test( - name = "saver_test", +py_test( + name = "saver_large_variable_test", size = "small", - srcs = ["training/saver_test.py"], - additional_deps = [ - ":training", - ], + srcs = ["training/saver_large_variable_test.py"], + srcs_version = "PY2AND3", tags = ["notsan"], # http://b/30379628 + deps = [ + "//tensorflow:tensorflow_py", + ], ) cuda_py_test( @@ -1311,6 +1317,7 @@ py_tests( ["training/input_test.py"], ), additional_deps = [ + "//tensorflow:tensorflow_py", ":training", ], ) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 2697d8b5bc2..d8c5737fbb2 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1217,8 +1217,7 @@ class InteractiveSession(BaseSession): Args: target: (Optional.) The execution engine to connect to. - Defaults to using an in-process engine. At present, no value - other than the empty string is supported. + Defaults to using an in-process engine. graph: (Optional.) The `Graph` to be launched (described above). config: (Optional) `ConfigProto` proto used to configure the session. """ diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 659aa4a748c..87391fff68d 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -213,7 +213,7 @@ tensorflow::ImportNumpy(); reinterpret_cast($1.data), $1.length); } -// Include the functions from tensor_c_api.h, except TF_Run. +// Include the functions from c_api.h, except TF_Run. %ignoreall %unignore TF_Code; %unignore TF_Status; @@ -238,7 +238,7 @@ tensorflow::ImportNumpy(); %unignore TF_NewLibrary; %unignore TF_LoadLibrary; %unignore TF_GetOpList; -%include "tensorflow/core/public/tensor_c_api.h" +%include "tensorflow/c/c_api.h" %ignoreall %insert("python") %{ diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index a203a7539af..68ef4920ced 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -17,13 +17,13 @@ limitations under the License. #include +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/equal_graph_def.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/tf_status_helper.h" namespace tensorflow { diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 591b8774b64..83cab586d8b 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -19,11 +19,11 @@ limitations under the License. // Must be included first #include "tensorflow/python/lib/core/numpy.h" +#include "tensorflow/c/c_api.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/public/tensor_c_api.h" namespace tensorflow { diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index a8ae70e49b0..00372831df6 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -721,6 +721,15 @@ class ControlFlowTest(tf.test.TestCase): c = tf.while_loop(lambda x: x < 10, lambda x: x + 1, [c]) self.assertEqual(10, sess.run(c, {b: True})) + def testWhileWithControl_4(self): + with self.test_session() as sess: + b = tf.placeholder(tf.bool) + c = tf.constant(1) + x0 = tf.constant(0) + with tf.control_dependencies([b]): + r = tf.while_loop(lambda x: x < 10, lambda x: x + tf.identity(c), [x0]) + self.assertEqual(10, sess.run(r, {b: True})) + def testCondWhile_1(self): with self.test_session(): n = tf.convert_to_tensor(0, name="n") diff --git a/tensorflow/python/kernel_tests/init_ops_test.py b/tensorflow/python/kernel_tests/init_ops_test.py index d23e4777e07..7d5323e5cb9 100644 --- a/tensorflow/python/kernel_tests/init_ops_test.py +++ b/tensorflow/python/kernel_tests/init_ops_test.py @@ -276,42 +276,55 @@ class RangeTest(tf.test.TestCase): # TODO(vrv): move to sequence_ops_test? class LinSpaceTest(tf.test.TestCase): + def _gpu_modes(self): + if tf.test.is_gpu_available(): + return [False, True] + else: + return [False] + def _LinSpace(self, start, stop, num): - with self.test_session(): - tf_ans = tf.linspace(start, stop, num, name="linspace") - self.assertEqual([num], tf_ans.get_shape()) - return tf_ans.eval() + # NOTE(touts): Needs to pass a graph to get a new session each time. + with tf.Graph().as_default() as graph: + with self.test_session(graph=graph, force_gpu=self.force_gpu): + tf_ans = tf.linspace(start, stop, num, name="linspace") + self.assertEqual([num], tf_ans.get_shape()) + return tf_ans.eval() def testPositive(self): - self.assertArrayNear(self._LinSpace(1., 5., 1), np.array([1.]), 1e-5) - self.assertArrayNear(self._LinSpace(1., 5., 2), np.array([1., 5.]), 1e-5) - self.assertArrayNear(self._LinSpace(1., 5., 3), - np.array([1., 3., 5.]), 1e-5) - self.assertArrayNear(self._LinSpace(1., 5., 4), - np.array([1., 7. / 3., 11. / 3., 5.]), 1e-5) + for self.force_gpu in self._gpu_modes(): + self.assertArrayNear(self._LinSpace(1., 5., 1), np.array([1.]), 1e-5) + self.assertArrayNear(self._LinSpace(1., 5., 2), np.array([1., 5.]), 1e-5) + self.assertArrayNear(self._LinSpace(1., 5., 3), + np.array([1., 3., 5.]), 1e-5) + self.assertArrayNear(self._LinSpace(1., 5., 4), + np.array([1., 7. / 3., 11. / 3., 5.]), 1e-5) def testNegative(self): - self.assertArrayNear(self._LinSpace(-1., -5., 1), np.array([-1.]), 1e-5) - self.assertArrayNear(self._LinSpace(-1., -5., 2), - np.array([-1., -5.]), 1e-5) - self.assertArrayNear(self._LinSpace(-1., -5., 3), - np.array([-1., -3., -5.]), 1e-5) - self.assertArrayNear(self._LinSpace(-1., -5., 4), - np.array([-1., -7. / 3., -11. / 3., -5.]), 1e-5) + for self.force_gpu in self._gpu_modes(): + self.assertArrayNear(self._LinSpace(-1., -5., 1), np.array([-1.]), 1e-5) + self.assertArrayNear(self._LinSpace(-1., -5., 2), + np.array([-1., -5.]), 1e-5) + self.assertArrayNear(self._LinSpace(-1., -5., 3), + np.array([-1., -3., -5.]), 1e-5) + self.assertArrayNear(self._LinSpace(-1., -5., 4), + np.array([-1., -7. / 3., -11. / 3., -5.]), 1e-5) def testNegativeToPositive(self): - self.assertArrayNear(self._LinSpace(-1., 5., 1), np.array([-1.]), 1e-5) - self.assertArrayNear(self._LinSpace(-1., 5., 2), np.array([-1., 5.]), 1e-5) - self.assertArrayNear(self._LinSpace(-1., 5., 3), - np.array([-1., 2., 5.]), 1e-5) - self.assertArrayNear(self._LinSpace(-1., 5., 4), - np.array([-1., 1., 3., 5.]), 1e-5) + for self.force_gpu in self._gpu_modes(): + self.assertArrayNear(self._LinSpace(-1., 5., 1), np.array([-1.]), 1e-5) + self.assertArrayNear(self._LinSpace(-1., 5., 2), np.array([-1., 5.]), + 1e-5) + self.assertArrayNear(self._LinSpace(-1., 5., 3), + np.array([-1., 2., 5.]), 1e-5) + self.assertArrayNear(self._LinSpace(-1., 5., 4), + np.array([-1., 1., 3., 5.]), 1e-5) def testPoint(self): - self.assertArrayNear(self._LinSpace(5., 5., 1), np.array([5.]), 1e-5) - self.assertArrayNear(self._LinSpace(5., 5., 2), np.array([5.] * 2), 1e-5) - self.assertArrayNear(self._LinSpace(5., 5., 3), np.array([5.] * 3), 1e-5) - self.assertArrayNear(self._LinSpace(5., 5., 4), np.array([5.] * 4), 1e-5) + for self.force_gpu in self._gpu_modes(): + self.assertArrayNear(self._LinSpace(5., 5., 1), np.array([5.]), 1e-5) + self.assertArrayNear(self._LinSpace(5., 5., 2), np.array([5.] * 2), 1e-5) + self.assertArrayNear(self._LinSpace(5., 5., 3), np.array([5.] * 3), 1e-5) + self.assertArrayNear(self._LinSpace(5., 5., 4), np.array([5.] * 4), 1e-5) class DeviceTest(tf.test.TestCase): diff --git a/tensorflow/python/kernel_tests/one_hot_op_test.py b/tensorflow/python/kernel_tests/one_hot_op_test.py index 913b5190a8b..9a9dbfe8c92 100644 --- a/tensorflow/python/kernel_tests/one_hot_op_test.py +++ b/tensorflow/python/kernel_tests/one_hot_op_test.py @@ -233,27 +233,54 @@ class OneHotTest(tf.test.TestCase): dtype=dtype, truth=[truth[0].T, truth[1].T]) # Do not transpose the batch + def _testEmpty(self, dtype): + indices = np.zeros((0, 16), dtype=np.int64) + depth = 3 + on_value = np.asarray(1.0, dtype=dtype) + off_value = np.asarray(-1.0, dtype=dtype) + truth = np.empty((0, 16, 3), dtype=dtype) + + # axis == -1 + self._testBothOneHot( + indices=indices, + depth=depth, + on_value=on_value, + off_value=off_value, + dtype=dtype, + truth=truth) + + def testHalfBatch(self): + self._testEmpty(np.float16) + self._testBatch(np.float16) + self._testDefaultValuesBatch(np.float16) + self._testValueTypeBatch(np.float16) + def testFloatBatch(self): + self._testEmpty(np.float32) self._testBatch(np.float32) self._testDefaultValuesBatch(np.float32) self._testValueTypeBatch(np.float32) def testDoubleBatch(self): + self._testEmpty(np.float64) self._testBatch(np.float64) self._testDefaultValuesBatch(np.float64) self._testValueTypeBatch(np.float64) def testInt32Batch(self): + self._testEmpty(np.int32) self._testBatch(np.int32) self._testDefaultValuesBatch(np.int32) self._testValueTypeBatch(np.int32) def testInt64Batch(self): + self._testEmpty(np.int64) self._testBatch(np.int64) self._testDefaultValuesBatch(np.int64) self._testValueTypeBatch(np.int64) def testComplexBatch(self): + self._testEmpty(np.complex64) self._testBatch(np.complex64) # self._testDefaultValuesBatch(np.complex64) self._testValueTypeBatch(np.complex64) diff --git a/tensorflow/python/lib/io/file_io.i b/tensorflow/python/lib/io/file_io.i index 42fbb906638..12ab8566e96 100644 --- a/tensorflow/python/lib/io/file_io.i +++ b/tensorflow/python/lib/io/file_io.i @@ -17,13 +17,13 @@ limitations under the License. %include "tensorflow/python/platform/base.i" %{ +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/io/match.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" -#include "tensorflow/core/util/tf_status_helper.h" %} %{ diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 18a7a20c11d..eee3b3e2d4e 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -118,6 +118,8 @@ def _Identity(data, name=None): else: return array_ops.identity(data, name=name) else: + if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(data)) values = _Identity(data.values, name=name) indices = array_ops.identity(data.indices, name="indices") if isinstance(data, ops.IndexedSlices): @@ -125,11 +127,9 @@ def _Identity(data, name=None): if dense_shape is not None: dense_shape = array_ops.identity(dense_shape, name="dense_shape") return ops.IndexedSlices(values, indices, dense_shape) - elif isinstance(data, ops.SparseTensor): + else: dense_shape = array_ops.identity(data.shape, name="dense_shape") return ops.SparseTensor(indices, values, dense_shape) - else: - raise TypeError("Type %s not supported" % type(data)) def _NextIteration(data, name=None): @@ -140,6 +140,8 @@ def _NextIteration(data, name=None): else: return next_iteration(data, name=name) else: + if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(data)) values = _NextIteration(data.values, name=name) indices = next_iteration(data.indices, name="indices") if isinstance(data, ops.IndexedSlices): @@ -147,11 +149,9 @@ def _NextIteration(data, name=None): if dense_shape is not None: dense_shape = next_iteration(dense_shape, name="dense_shape") return ops.IndexedSlices(values, indices, dense_shape) - elif isinstance(data, ops.SparseTensor): + else: dense_shape = next_iteration(data.shape, name="dense_shape") return ops.SparseTensor(indices, values, dense_shape) - else: - raise TypeError("Type %s not supported" % type(data)) def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, @@ -183,6 +183,8 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, return enter(data, frame_name, is_constant, parallel_iterations, name=name) else: + if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(data)) values = _Enter(data.values, frame_name, is_constant, parallel_iterations, name=name) indices = enter(data.indices, frame_name, is_constant, @@ -193,12 +195,10 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10, dense_shape = enter(dense_shape, frame_name, is_constant, parallel_iterations, name="dense_shape") return ops.IndexedSlices(values, indices, dense_shape) - elif isinstance(data, ops.SparseTensor): + else: dense_shape = enter(data.shape, frame_name, is_constant, parallel_iterations, name="dense_shape") return ops.SparseTensor(indices, values, dense_shape) - else: - raise TypeError("Type %s not supported" % type(data)) def exit(data, name=None): @@ -220,6 +220,8 @@ def exit(data, name=None): else: return gen_control_flow_ops._exit(data, name) else: + if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(data)) values = exit(data.values, name=name) indices = gen_control_flow_ops._exit(data.indices, name="indices") if isinstance(data, ops.IndexedSlices): @@ -227,11 +229,9 @@ def exit(data, name=None): if dense_shape is not None: dense_shape = gen_control_flow_ops._exit(dense_shape, name) return ops.IndexedSlices(values, indices, dense_shape) - elif isinstance(data, ops.SparseTensor): + else: dense_shape = gen_control_flow_ops._exit(data.shape, name) return ops.SparseTensor(indices, values, dense_shape) - else: - raise TypeError("Type %s not supported" % type(data)) def switch(data, pred, dtype=None, name=None): @@ -1493,8 +1493,15 @@ class WhileContext(ControlFlowContext): self._AddOpInternal(op) def _AddOpInternal(self, op): - """Add `op` to the current context.""" + """Add `op` to the current context. + + In the case that op has only external data inputs, we remove all of its + external control inputs so all its inputs are in the same while loop + context. This is valid because op now has an Enter input that has all + the right control dependency. + """ if not op.inputs: + # Remove any external control dependency on this op control_inputs = [x for x in op.control_inputs if x._get_control_flow_context() == self] if len(control_inputs) != len(op.control_inputs): @@ -1508,12 +1515,22 @@ class WhileContext(ControlFlowContext): for x in op.outputs: self._values.add(x.name) else: + has_internal_data_input = False for index in range(len(op.inputs)): x = op.inputs[index] self.AddValue(x) real_x = self._external_values.get(x.name) if real_x is not None: op._update_input(index, real_x) + else: + has_internal_data_input = True + if not has_internal_data_input: + # Remove any external control dependency on this op + control_inputs = [x for x in op.control_inputs + if x._get_control_flow_context() == self] + if len(control_inputs) != len(op.control_inputs): + del op.control_inputs[:] + op._add_control_inputs(control_inputs) # Add a control dependency to prevent loop invariants from # enabling ops that should not be executed. self._MaybeAddControlDependency(op) @@ -1879,6 +1896,8 @@ class WhileContext(ControlFlowContext): if isinstance(e, ops.Tensor): xs = [e] else: + if not isinstance(e, (ops.IndexedSlices, ops.SparseTensor)): + raise TypeError("Type %s not supported" % type(e)) xs = [e.values, e.indices] shape = e.dense_shape if isinstance(e, ops.IndexedSlices) else e.shape if shape is not None: diff --git a/tensorflow/python/training/saver_large_variable_test.py b/tensorflow/python/training/saver_large_variable_test.py new file mode 100644 index 00000000000..40f0a47e430 --- /dev/null +++ b/tensorflow/python/training/saver_large_variable_test.py @@ -0,0 +1,49 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for tensorflow.python.training.saver.py.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + + +class SaverLargeVariableTest(tf.test.TestCase): + + # NOTE: This is in a separate file from saver_test.py because the + # large allocations do not play well with TSAN, and cause flaky + # failures. + def testLargeVariable(self): + save_path = os.path.join(self.get_temp_dir(), "large_variable") + with tf.Session("", graph=tf.Graph()) as sess: + # Declare a variable that is exactly 2GB. This should fail, + # because a serialized checkpoint includes other header + # metadata. + with tf.device("/cpu:0"): + var = tf.Variable( + tf.constant(False, shape=[2, 1024, 1024, 1024], dtype=tf.bool)) + save = tf.train.Saver({var.op.name: var}) + var.initializer.run() + with self.assertRaisesRegexp( + tf.errors.InvalidArgumentError, + "Tensor slice is too large to serialize"): + save.save(sess, save_path) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index 629ba4bbc14..7eb1e7e519d 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -287,22 +287,6 @@ class SaverTest(tf.test.TestCase): expected_save_path = "%s-%d" % (save_path, global_step_int) self.assertEqual(expected_save_path, val) - def testLargeVariable(self): - save_path = os.path.join(self.get_temp_dir(), "large_variable") - with tf.Session("", graph=tf.Graph()) as sess: - # Declare a variable that is exactly 2GB. This should fail, - # because a serialized checkpoint includes other header - # metadata. - with tf.device("/cpu:0"): - var = tf.Variable( - tf.constant(False, shape=[2, 1024, 1024, 1024], dtype=tf.bool)) - save = tf.train.Saver({var.op.name: var}) - var.initializer.run() - with self.assertRaisesRegexp( - tf.errors.InvalidArgumentError, - "Tensor slice is too large to serialize"): - save.save(sess, save_path) - class SaveRestoreShardedTest(tf.test.TestCase): diff --git a/tensorflow/python/training/server_lib.i b/tensorflow/python/training/server_lib.i index 6f5f3d4fa56..94250304f85 100644 --- a/tensorflow/python/training/server_lib.i +++ b/tensorflow/python/training/server_lib.i @@ -58,9 +58,9 @@ limitations under the License. } %{ +#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/tf_status_helper.h" using tensorflow::ServerDef; diff --git a/tensorflow/python/util/py_checkpoint_reader.i b/tensorflow/python/util/py_checkpoint_reader.i index 39e14924dc3..acec8d03ab1 100644 --- a/tensorflow/python/util/py_checkpoint_reader.i +++ b/tensorflow/python/util/py_checkpoint_reader.i @@ -17,8 +17,8 @@ limitations under the License. %include "tensorflow/python/platform/base.i" %{ +#include "tensorflow/c/checkpoint_reader.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/util/checkpoint_reader.h" #include "tensorflow/python/lib/core/py_func.h" %} @@ -126,5 +126,5 @@ def NewCheckpointReader(filepattern): return CheckpointReader(compat.as_bytes(filepattern), status) %} -%include "tensorflow/core/util/checkpoint_reader.h" +%include "tensorflow/c/checkpoint_reader.h" %unignoreall diff --git a/tensorflow/tools/ci_build/update_version.sh b/tensorflow/tools/ci_build/update_version.sh index 2e8a4a57bc0..1d1e492ef87 100755 --- a/tensorflow/tools/ci_build/update_version.sh +++ b/tensorflow/tools/ci_build/update_version.sh @@ -101,7 +101,7 @@ OS_SETUP="${TF_SRC_DIR}/g3doc/get_started/os_setup.md" check_existence file "${OS_SETUP}" sed -i -r -e "s/(.*pip[0-9]* install .*tensorflow-)([0-9]+\.[0-9]+\.[[:alnum:]]+)(-.*\.whl)/\1${MAJOR}.${MINOR}.${PATCH}\3/g" "${OS_SETUP}" - +sed -i -r -e "s/(.*export TF_BINARY_URL.*tensorflow-)([0-9]+\.[0-9]+\.[[:alnum:]]+)(-.*\.whl)/\1${MAJOR}.${MINOR}.${PATCH}\3/g" "${OS_SETUP}" sed -i -r -e "s/(.*\(e\.g\..*[^0-9])([0-9]+\.[0-9]+\.[[:alnum:]]+)(-gpu.*)/\1${MAJOR}.${MINOR}.${PATCH}\3/g" "${OS_SETUP}"