commit
88d9bc16d6
tensorflow
BUILD
c
BUILDc_api.ccc_api.hc_api_test.cccheckpoint_reader.cccheckpoint_reader.htf_status_helper.cctf_status_helper.h
contrib
distributions
framework
layers/python/layers
learn/python/learn
makefile
quantization/ops
slim
core
g3doc
api_docs/python
resources
tutorials
python
tools/ci_build
@ -71,6 +71,7 @@ filegroup(
|
||||
name = "all_opensource_files",
|
||||
data = [
|
||||
":all_files",
|
||||
"//tensorflow/c:all_files",
|
||||
"//tensorflow/cc:all_files",
|
||||
"//tensorflow/contrib:all_files",
|
||||
"//tensorflow/contrib/copy_graph:all_files",
|
||||
|
95
tensorflow/c/BUILD
Normal file
95
tensorflow/c/BUILD
Normal file
@ -0,0 +1,95 @@
|
||||
# Description:
|
||||
# C API for TensorFlow, for use by client language bindings.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
"tf_cuda_library",
|
||||
)
|
||||
|
||||
# For platform specific build config
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_kernel_tests_linkstatic",
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Public targets
|
||||
|
||||
tf_cuda_library(
|
||||
name = "c_api",
|
||||
srcs = ["c_api.cc"],
|
||||
hdrs = ["c_api.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "tf_status_helper",
|
||||
srcs = ["tf_status_helper.cc"],
|
||||
hdrs = ["tf_status_helper.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":c_api",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "checkpoint_reader",
|
||||
srcs = ["checkpoint_reader.cc"],
|
||||
hdrs = ["checkpoint_reader.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":tf_status_helper",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tests
|
||||
|
||||
tf_cc_test(
|
||||
name = "c_api_test",
|
||||
size = "small",
|
||||
linkopts = select({
|
||||
"//tensorflow:darwin": ["-headerpad_max_install_names"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
deps = [
|
||||
":c_api",
|
||||
"//tensorflow/core:direct_session",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:proto_text",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/kernels:array",
|
||||
"//tensorflow/core/kernels:math",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Google-internal targets.
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/public/tensor_c_api.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// TODO(jeff,sanjay): Rename to tensorflow/public/c_api.h
|
||||
#ifndef TENSORFLOW_PUBLIC_TENSOR_C_API_H_
|
||||
#define TENSORFLOW_PUBLIC_TENSOR_C_API_H_
|
||||
#ifndef TENSORFLOW_C_C_API_H_
|
||||
#define TENSORFLOW_C_C_API_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
@ -699,4 +698,4 @@ extern TF_Buffer TF_GetOpList(TF_Library* lib_handle);
|
||||
} /* end extern "C" */
|
||||
#endif
|
||||
|
||||
#endif // TENSORFLOW_PUBLIC_TENSOR_C_API_H_
|
||||
#endif // TENSORFLOW_C_C_API_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/public/tensor_c_api.h"
|
||||
#include "tensorflow/c/c_api.h"
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/graph.pb_text.h"
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/checkpoint_reader.h"
|
||||
#include "tensorflow/c/checkpoint_reader.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H
|
||||
#define TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H
|
||||
#ifndef TENSORFLOW_C_CHECKPOINT_READER_H
|
||||
#define TENSORFLOW_C_CHECKPOINT_READER_H
|
||||
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/tensor_slice_reader.h"
|
||||
#include "tensorflow/core/util/tf_status_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -60,4 +60,4 @@ class CheckpointReader {
|
||||
} // namespace checkpoint
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_CHECKPOINT_READER_H
|
||||
#endif // TENSORFLOW_C_CHECKPOINT_READER_H
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/util/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H
|
||||
#define TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H
|
||||
#ifndef TENSORFLOW_C_TF_STATUS_HELPER_H
|
||||
#define TENSORFLOW_C_TF_STATUS_HELPER_H
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/public/tensor_c_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -26,4 +26,4 @@ void Set_TF_Status_from_Status(TF_Status* tf_status, const Status& status);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_UTIL_TF_STATUS_HELPER_H
|
||||
#endif // TENSORFLOW_C_TF_STATUS_HELPER_H
|
@ -263,6 +263,28 @@ cuda_py_tests(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "shape_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/shape_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_tests(
|
||||
name = "bijector_test",
|
||||
size = "small",
|
||||
srcs = ["python/kernel_tests/bijector_test.py"],
|
||||
additional_deps = [
|
||||
":distributions_py",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -0,0 +1,67 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for Bijector."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.distributions.python.ops.bijector import _Exp # pylint: disable=line-too-long
|
||||
from tensorflow.contrib.distributions.python.ops.bijector import _Identity # pylint: disable=line-too-long
|
||||
from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long
|
||||
|
||||
|
||||
class IdentityBijectorTest(tf.test.TestCase):
|
||||
"""Tests the correctness of the Y = g(X) = X transformation."""
|
||||
|
||||
def testBijector(self):
|
||||
with self.test_session():
|
||||
bijector = _Identity(_ShapeUtil(batch_ndims=1, event_ndims=1))
|
||||
self.assertEqual(bijector.name, 'Identity')
|
||||
x = [[[0.], [1]]]
|
||||
self.assertAllEqual(bijector.forward(x).eval(), x)
|
||||
self.assertAllEqual(bijector.inverse(x).eval(), x)
|
||||
self.assertAllEqual(bijector.inverse_log_det_jacobian(x).eval(),
|
||||
[[0., 0]])
|
||||
rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x)
|
||||
self.assertAllEqual(rev.eval(), x)
|
||||
self.assertAllEqual(jac.eval(), [[0., 0]])
|
||||
|
||||
|
||||
class ExpBijectorTest(tf.test.TestCase):
|
||||
"""Tests the correctness of the Y = g(X) = exp(X) transformation."""
|
||||
|
||||
def testBijector(self):
|
||||
with self.test_session():
|
||||
bijector = _Exp(_ShapeUtil(batch_ndims=1, event_ndims=1))
|
||||
self.assertEqual(bijector.name, 'Exp')
|
||||
x = [[[1.], [2]]]
|
||||
self.assertAllClose(bijector.forward(x).eval(),
|
||||
[[[math.exp(1.)], [math.exp(2.)]]])
|
||||
self.assertAllClose(bijector.inverse(x).eval(),
|
||||
[[[math.log(1.)], [math.log(2.)]]])
|
||||
self.assertAllClose(bijector.inverse_log_det_jacobian(x).eval(),
|
||||
[[0., -math.log(2.)]])
|
||||
rev, jac = bijector.inverse_and_inverse_log_det_jacobian(x)
|
||||
self.assertAllClose(rev.eval(), [[[math.log(1.)], [math.log(2.)]]])
|
||||
self.assertAllClose(jac.eval(), [[0., -math.log(2.)]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.test.main()
|
@ -0,0 +1,165 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for ShapeUtil."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.contrib.distributions.python.ops.shape import _ShapeUtil # pylint: disable=line-too-long
|
||||
|
||||
|
||||
class ShapeUtilTest(tf.test.TestCase):
|
||||
|
||||
def testShapeUtilGetNdims(self):
|
||||
with self.test_session():
|
||||
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||
x = 1
|
||||
self.assertEqual(shaper.get_sample_ndims(x), 0)
|
||||
self.assertEqual(shaper.batch_ndims, 0)
|
||||
self.assertEqual(shaper.event_ndims, 0)
|
||||
|
||||
shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
|
||||
x = [[[0., 1, 2], [3, 4, 5]]]
|
||||
self.assertAllEqual(shaper.get_ndims(x), 3)
|
||||
self.assertEqual(shaper.get_sample_ndims(x), 1)
|
||||
self.assertEqual(shaper.batch_ndims, 1)
|
||||
self.assertEqual(shaper.event_ndims, 1)
|
||||
|
||||
x += [[[6, 7, 8], [9, 10, 11]]]
|
||||
self.assertAllEqual(shaper.get_ndims(x), 3)
|
||||
self.assertEqual(shaper.get_sample_ndims(x), 1)
|
||||
self.assertEqual(shaper.batch_ndims, 1)
|
||||
self.assertEqual(shaper.event_ndims, 1)
|
||||
|
||||
# Test that ndims functions work even with unfed Tensors.
|
||||
y = tf.placeholder(tf.float32, shape=(1024, None, 1024))
|
||||
self.assertAllEqual(shaper.get_ndims(y), 3)
|
||||
self.assertEqual(shaper.get_sample_ndims(y), 1)
|
||||
self.assertEqual(shaper.batch_ndims, 1)
|
||||
self.assertEqual(shaper.event_ndims, 1)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
y = tf.placeholder(tf.float32)
|
||||
shaper.get_ndims(y)
|
||||
|
||||
def testShapeUtilGetDims(self):
|
||||
with self.test_session():
|
||||
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||
with self.assertRaises(ValueError):
|
||||
y = tf.placeholder(tf.float32)
|
||||
shaper.get_sample_dims(y)
|
||||
with self.assertRaises(ValueError):
|
||||
y = tf.placeholder(tf.float32)
|
||||
shaper.get_batch_dims(y)
|
||||
with self.assertRaises(ValueError):
|
||||
y = tf.placeholder(tf.float32)
|
||||
shaper.get_event_dims(y)
|
||||
|
||||
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||
x = 1
|
||||
self.assertAllEqual(shaper.get_sample_dims(x), [])
|
||||
self.assertAllEqual(shaper.get_batch_dims(x), [])
|
||||
self.assertAllEqual(shaper.get_event_dims(x), [])
|
||||
self.assertAllEqual(shaper.get_dims(x, sample=False), [])
|
||||
|
||||
shaper = _ShapeUtil(batch_ndims=1, event_ndims=2)
|
||||
x = [[[[0., 1], [2, 4]]]]
|
||||
self.assertAllEqual(shaper.get_sample_dims(x), [0])
|
||||
self.assertAllEqual(shaper.get_batch_dims(x), [1])
|
||||
self.assertAllEqual(shaper.get_event_dims(x), [2, 3])
|
||||
self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3])
|
||||
|
||||
x += x
|
||||
self.assertAllEqual(shaper.get_sample_dims(x), [0])
|
||||
self.assertAllEqual(shaper.get_batch_dims(x), [1])
|
||||
self.assertAllEqual(shaper.get_event_dims(x), [2, 3])
|
||||
self.assertAllEqual(shaper.get_dims(x, sample=False), [1, 2, 3])
|
||||
|
||||
# Test dims functions work, despite unfed Tensors.
|
||||
y = tf.placeholder(tf.float32, shape=(1024, None, 5, 5))
|
||||
self.assertAllEqual(shaper.get_sample_dims(y), [0])
|
||||
self.assertAllEqual(shaper.get_batch_dims(y), [1])
|
||||
self.assertAllEqual(shaper.get_event_dims(y), [2, 3])
|
||||
|
||||
def testShapeUtilGetShape(self):
|
||||
with self.test_session() as sess:
|
||||
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||
with self.assertRaises(ValueError):
|
||||
y = tf.placeholder(tf.float32)
|
||||
shaper.get_sample_shape(y)
|
||||
with self.assertRaises(ValueError):
|
||||
y = tf.placeholder(tf.float32)
|
||||
shaper.get_batch_shape(y)
|
||||
with self.assertRaises(ValueError):
|
||||
y = tf.placeholder(tf.float32)
|
||||
shaper.get_event_shape(y)
|
||||
|
||||
shaper = _ShapeUtil(batch_ndims=0, event_ndims=0)
|
||||
x = 1
|
||||
self.assertAllEqual(shaper.get_sample_shape(x), [])
|
||||
self.assertAllEqual(shaper.get_batch_shape(x), [])
|
||||
self.assertAllEqual(shaper.get_event_shape(x), [])
|
||||
self.assertAllEqual(shaper.get_shape(x, batch=False), [])
|
||||
|
||||
shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
|
||||
x = [[[0., 1, 2], [3, 4, 5]]]
|
||||
self.assertAllEqual(shaper.get_sample_shape(x), [1])
|
||||
self.assertAllEqual(shaper.get_batch_shape(x), [2])
|
||||
self.assertAllEqual(shaper.get_event_shape(x), [3])
|
||||
self.assertAllEqual(shaper.get_shape(x, batch=False), [1, 3])
|
||||
|
||||
x += [[[6, 7, 8], [9, 10, 11]]]
|
||||
self.assertAllEqual(shaper.get_sample_shape(x), [2])
|
||||
self.assertAllEqual(shaper.get_batch_shape(x), [2])
|
||||
self.assertAllEqual(shaper.get_event_shape(x), [3])
|
||||
self.assertAllEqual(shaper.get_shape(x, batch=False), [2, 3])
|
||||
|
||||
shaper = _ShapeUtil(batch_ndims=0, event_ndims=1)
|
||||
x = tf.ones((3, 2))
|
||||
self.assertAllEqual(shaper.get_shape(x, sample=False), (2,))
|
||||
|
||||
def feed_eval(fun, build_shape=(None, None, 2), graph_shape=(3, 4, 2)):
|
||||
"""Helper to use a deferred-shape tensor eval'ed at graph runtime."""
|
||||
y = tf.placeholder(tf.int32, shape=build_shape)
|
||||
y_value = np.ones(graph_shape, dtype=y.dtype.as_numpy_dtype())
|
||||
return sess.run(fun(y),
|
||||
feed_dict={y: y_value})
|
||||
|
||||
shaper = _ShapeUtil(batch_ndims=1, event_ndims=1)
|
||||
self.assertAllEqual(feed_eval(shaper.get_sample_shape), [3])
|
||||
self.assertAllEqual(feed_eval(shaper.get_batch_shape), [4])
|
||||
self.assertAllEqual(feed_eval(shaper.get_event_shape), [2])
|
||||
self.assertAllEqual(
|
||||
feed_eval(lambda y: shaper.get_shape(y, batch=False)),
|
||||
[3, 2])
|
||||
|
||||
shaper = _ShapeUtil(batch_ndims=0, event_ndims=1)
|
||||
self.assertAllEqual(
|
||||
feed_eval(lambda y: shaper.get_shape(y, batch=False),
|
||||
(None, None),
|
||||
(3, 2)),
|
||||
[3, 2])
|
||||
self.assertAllEqual(
|
||||
feed_eval(lambda y: shaper.get_shape(y, sample=False),
|
||||
(None, None),
|
||||
(3, 2)),
|
||||
[2])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
350
tensorflow/contrib/distributions/python/ops/bijector.py
Normal file
350
tensorflow/contrib/distributions/python/ops/bijector.py
Normal file
@ -0,0 +1,350 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""An API for reversible (bijective) transformations of random variables."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
|
||||
class _Bijector(object):
|
||||
"""An interface for transforming random variable(s).
|
||||
|
||||
A bijector is characterized by three operations:
|
||||
|
||||
1) Forward Evaluation
|
||||
Useful for turning one random outcome into another random outcome from a
|
||||
different distribution.
|
||||
|
||||
2) Inverse Evaluation
|
||||
Useful for "reversing" a transformation to compute one probability in terms
|
||||
of another.
|
||||
|
||||
3) (log o det o Jacobian o inverse)(x)
|
||||
"The log of the determinant of the matrix of all first-order partial
|
||||
derivatives of the inverse function."
|
||||
Useful for inverting a transformation to compute one probability in terms
|
||||
of another. Geometrically, the det(Jacobian) is the volume of the
|
||||
transformation and is used to scale the probability.
|
||||
|
||||
By convention, transformations of random variables are named in terms of the
|
||||
forward transformation. The forward transformation creates samples, the
|
||||
inverse is useful for computing probabilities.
|
||||
|
||||
Example transformations:
|
||||
"Exponential"
|
||||
|
||||
```
|
||||
Y = g(X) = exp(X)
|
||||
X ~ Normal(0, 1) # Univariate.
|
||||
```
|
||||
|
||||
Implies:
|
||||
|
||||
```
|
||||
g^{-1}(Y) = log(Y)
|
||||
|Jacobian(g^{-1})(y)| = 1 / y
|
||||
Y ~ LogNormal(0, 1), i.e.,
|
||||
prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
|
||||
= (1 / y) Normal(log(y); 0, 1)
|
||||
```
|
||||
|
||||
"ShiftAndScale"
|
||||
|
||||
```
|
||||
Y = g(X) = sqrtSigma * X + mu
|
||||
X ~ MultivariateNormal(0, I_d)
|
||||
```
|
||||
|
||||
Implies:
|
||||
|
||||
```
|
||||
g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
|
||||
|Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
|
||||
Y ~ MultivariateNormal(mu, sqrtSigma), i.e.,
|
||||
prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
|
||||
= det(sqrtSigma)^(-1) *
|
||||
MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
|
||||
```
|
||||
|
||||
Example use:
|
||||
Basic properties:
|
||||
|
||||
```python
|
||||
x = ... # A tensor.
|
||||
# Evaluate forward transformation.
|
||||
fwd_x = my_bijector.forward(x)
|
||||
x != my_bijector.forward(fwd_x) # Not equal because g(x) != g(g(x)).
|
||||
x == my_bijector.inverse(fwd_x)
|
||||
```
|
||||
|
||||
Computing a log-likelihood:
|
||||
|
||||
```python
|
||||
def transformed_log_pdf(bijector, log_pdf, x):
|
||||
return (bijector.inverse_log_det_jacobian(x) +
|
||||
log_pdf(bijector.inverse(x)))
|
||||
```
|
||||
|
||||
Transforming a random outcome:
|
||||
|
||||
```python
|
||||
def transformed_sample(bijector, x):
|
||||
return bijector.forward(x)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
# TODO(b/30476956): Try to remove constructor dependence on shape util.
|
||||
def __init__(self, shaper=None, name=None):
|
||||
"""Constructs Bijector.
|
||||
|
||||
A bijector transforms random variables into new random variables. Managing
|
||||
shape is typically an important piece of this so a Bijector is usually
|
||||
composed of ShapeUtil. The ShapeUtil object handles input shape checks as
|
||||
well as reshaping/transposing for easier linear algebra operations.
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Create the Y = g(X) = X transform which operates on 4-Tensors of vectors.
|
||||
identity = Identity(ShapeUtil(batch_ndims=4, event_ndims=1))
|
||||
|
||||
# Create the Y = g(X) = exp(X) transform which operates on matrices.
|
||||
exp = Exp(ShapeUtil(batch_ndims=0, event_ndims=2))
|
||||
```
|
||||
|
||||
See Bijector subclass doc for more details and examples.
|
||||
|
||||
Args:
|
||||
shaper: object used for managing and manipulating shape, typically an
|
||||
instance of ShapeUtil.
|
||||
name: The name to give Ops created by the initializer.
|
||||
"""
|
||||
self._shaper = shaper
|
||||
self._name = name or type(self).__name__
|
||||
|
||||
@property
|
||||
def shaper(self):
|
||||
"""Returns shape object used to manage shape constraints."""
|
||||
return self._shaper
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Returns the string name of this bijector."""
|
||||
return self._name
|
||||
|
||||
def forward(self, x, name='forward'):
|
||||
"""Returns the forward bijector evaluation, i.e., Y = g(X).
|
||||
|
||||
Args:
|
||||
x: `Tensor`. The input to the "forward" evaluation.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
`Tensor`.
|
||||
"""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope([x], name):
|
||||
x = ops.convert_to_tensor(x)
|
||||
return self._forward(x)
|
||||
|
||||
def inverse(self, x, name='inverse'):
|
||||
"""Returns the inverse bijector evaluation, i.e., X = g^{-1}(Y).
|
||||
|
||||
Args:
|
||||
x: `Tensor`. The input to the "inverse" evaluation.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
`Tensor`.
|
||||
"""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope([x], name):
|
||||
x = ops.convert_to_tensor(x)
|
||||
try:
|
||||
return self._inverse(x)
|
||||
except NotImplementedError:
|
||||
return self._inverse_and_inverse_log_det_jacobian(x)[0]
|
||||
|
||||
def inverse_log_det_jacobian(self, x, name='inverse_log_det_jacobian'):
|
||||
"""Returns the (log o det o Jacobian o inverse)(x).
|
||||
|
||||
Mathematically, returns: log(det(dX/dY))(Y), where X = g^{-1}(Y).
|
||||
|
||||
Args:
|
||||
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
`Tensor`.
|
||||
"""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope([x], name):
|
||||
x = ops.convert_to_tensor(x)
|
||||
try:
|
||||
return self._inverse_log_det_jacobian(x)
|
||||
except NotImplementedError:
|
||||
return self._inverse_and_inverse_log_det_jacobian(x)[1]
|
||||
|
||||
def inverse_and_inverse_log_det_jacobian(
|
||||
self, x, name='inverse_and_inverse_log_det_jacobian'):
|
||||
"""Returns both the inverse evaluation and inverse_log_det_jacobian.
|
||||
|
||||
Enables possibly more efficient calculation when both inverse and
|
||||
corresponding Jacobian are needed.
|
||||
|
||||
See `inverse()`, `inverse_log_det_jacobian()` for more details.
|
||||
|
||||
Args:
|
||||
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
|
||||
name: The name to give this op.
|
||||
|
||||
Returns:
|
||||
Tuple of `Tensor`s: `(inverse, inverse_log_det_jacobian)`.
|
||||
"""
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope([x], name):
|
||||
x = ops.convert_to_tensor(x)
|
||||
try:
|
||||
return self._inverse_and_inverse_log_det_jacobian(x)
|
||||
except NotImplementedError:
|
||||
return self._inverse(x), self._inverse_log_det_jacobian(x)
|
||||
|
||||
# Subclass interface.
|
||||
def _forward(self, x):
|
||||
"""Subclass implementation of forward().
|
||||
|
||||
Args:
|
||||
x: `Tensor`. The input to the "forward" evaluation.
|
||||
|
||||
Raises:
|
||||
`NotImplementedError`: if subclass implementation not provided
|
||||
|
||||
Returns:
|
||||
`Tensor`.
|
||||
"""
|
||||
raise NotImplementedError('_forward not implemented')
|
||||
|
||||
def _inverse(self, x):
|
||||
"""Subclass implementation of inverse().
|
||||
|
||||
Args:
|
||||
x: `Tensor`. The input to the "inverse" evaluation.
|
||||
|
||||
Raises:
|
||||
`NotImplementedError`: if subclass implementation not provided
|
||||
|
||||
Returns:
|
||||
`Tensor`.
|
||||
"""
|
||||
raise NotImplementedError('_inverse not implemented')
|
||||
|
||||
def _inverse_log_det_jacobian(self, x):
|
||||
"""Subclass implementation of inverse_log_det_jacobian().
|
||||
|
||||
Args:
|
||||
x: `Tensor`. The input to the "inverse" Jacobian evaluation.
|
||||
|
||||
Raises:
|
||||
`NotImplementedError`: if subclass implementation not provided
|
||||
|
||||
Returns:
|
||||
`Tensor`.
|
||||
"""
|
||||
raise NotImplementedError('_inverse_log_det_jacobian not implemented')
|
||||
|
||||
def _inverse_and_inverse_log_det_jacobian(self, x):
|
||||
"""Subclass implementation of inverse_and_inverse_log_det_jacobian().
|
||||
|
||||
Args:
|
||||
x: `Tensor`. The input to the "inverse" evaluation.
|
||||
|
||||
Returns:
|
||||
List of two `Tensor` items, inverse and inverse_log_det_jacobian.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
'_inverse_and_inverse_log_det_jacobian not implemented')
|
||||
|
||||
|
||||
class _Identity(_Bijector):
|
||||
"""Bijector which computes Y = g(X) = X.
|
||||
|
||||
Example Use:
|
||||
```python
|
||||
# Create the Y=g(X)=X transform which works only on Tensors with 1 batch
|
||||
# ndims and 1 event ndim (i.e., vector of vectors).
|
||||
identity = Identity(ShapeUtil(batch_ndims=1, event_ndims=1))
|
||||
x = [[1., 2],
|
||||
[3, 4]]
|
||||
x == identity.forward(x) == identity.inverse(x)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
# TODO(b/30476956): Try to remove constructor dependence on shape util.
|
||||
def __init__(self, shaper=None, name='Identity'):
|
||||
super(_Identity, self).__init__(shaper, name)
|
||||
|
||||
def _forward(self, x):
|
||||
return x
|
||||
|
||||
def _inverse(self, x):
|
||||
return x
|
||||
|
||||
def _inverse_log_det_jacobian(self, x):
|
||||
result_shape = self.shaper.get_shape(
|
||||
x, sample=True, batch=True, event=False)
|
||||
return array_ops.zeros(result_shape, dtype=x.dtype)
|
||||
|
||||
|
||||
class _Exp(_Bijector):
|
||||
"""Bijector which computes Y = g(X) = exp(X).
|
||||
|
||||
Example Use:
|
||||
```python
|
||||
# Create the Y=g(X)=exp(X) transform which works only on Tensors with 1
|
||||
# batch ndims and 2 event ndims (i.e., vector of matrices).
|
||||
exp = Exp(ShapeUtil(batch_ndims=1, event_ndims=2))
|
||||
x = [[[1., 2],
|
||||
[3, 4]],
|
||||
[[5, 6],
|
||||
[7, 8]]]
|
||||
exp(x) == exp.forward(x)
|
||||
log(x) == exp.inverse(x)
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
# TODO(b/30476956): Try to remove constructor dependence on shape util.
|
||||
def __init__(self, shaper=None, name='Exp'):
|
||||
super(_Exp, self).__init__(shaper, name)
|
||||
|
||||
def _forward(self, x):
|
||||
return math_ops.exp(x)
|
||||
|
||||
def _inverse(self, x):
|
||||
return math_ops.log(x)
|
||||
|
||||
def _inverse_log_det_jacobian(self, x):
|
||||
d = self.shaper.get_event_dims(x)
|
||||
return -math_ops.reduce_sum(math_ops.log(x), d)
|
||||
|
||||
def _inverse_and_inverse_log_det_jacobian(self, x):
|
||||
y = math_ops.log(x)
|
||||
d = self.shaper.get_event_dims(x)
|
||||
return y, -math_ops.reduce_sum(y, d)
|
396
tensorflow/contrib/distributions/python/ops/shape.py
Normal file
396
tensorflow/contrib/distributions/python/ops/shape.py
Normal file
@ -0,0 +1,396 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A helper class for inferring Distribution shape."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
|
||||
|
||||
class _ShapeUtil(object):
|
||||
"""Class which helps infer/identify subsets of tensor dimensions.
|
||||
|
||||
Terminology:
|
||||
Recall that a `Tensor` has:
|
||||
shape: sizes of tensor dimensions,
|
||||
ndims: size of shape; number of tensor dimensions,
|
||||
dims: indexes into shape; useful for transpose, reduce.
|
||||
|
||||
Tensors sampled from a `Distribution` can be partitioned by:
|
||||
sample dims: indexes independent, identically distributed (iid) draws,
|
||||
batch dims: indexes non-identical draws,
|
||||
event dims: indexes coordinates of a single draw.
|
||||
|
||||
The sample, batch, and event dimensions constitute the entirety of a
|
||||
`Tensor` shape. The dimensions are always in sample, batch, event order.
|
||||
|
||||
Assumptions:
|
||||
We assume that batch_ndims and event_ndims are statically known for both
|
||||
creating this object and for inputs to its functions.
|
||||
TODO(jvdillon): Relax this assumption and support fully unknown shape.
|
||||
|
||||
We also assume that the `Tensor` rank is static, i.e., `x.get_shape().ndims
|
||||
is not None`.
|
||||
|
||||
Possible use-cases:
|
||||
~ Sample dimensions:
|
||||
Computing summary statistics, i.e., the average is a reduction over sample
|
||||
dimensions.
|
||||
|
||||
~ Batch dimensions:
|
||||
Log-likelihood under model predicted location:
|
||||
```python
|
||||
mu = ... # vector of predictions, one for each covariate.
|
||||
neg_log_likelihood = -tf.reduce_mean(
|
||||
Normal(loc=mu, scale=1).log_pdf(x),
|
||||
reduce_dims=[0])
|
||||
```
|
||||
|
||||
Monte Carlo estimation of a marginal probability:
|
||||
Average over batch dimensions where batch dimensions are associated with
|
||||
random draws of a prior.
|
||||
E.g., suppose we want to find the Monte Carlo estimate of the marginal
|
||||
distribution of a Normal with a random Laplace location:
|
||||
```
|
||||
P(X=x) = integral P(X=x|y) P(Y=y) dy
|
||||
~= 1/n sum_{i=1}^n P(X=x|y_i), y_i ~iid Laplace(0,1)
|
||||
= tf.reduce_mean(Normal(loc=Laplace(0, 1).sample_n(n=1000),
|
||||
scale=tf.ones([1000, 1])).pdf(x),
|
||||
reduce_dims=[0])
|
||||
```
|
||||
|
||||
The `Laplace` distribution generates a tensor of shape [1000, 1]. When fed
|
||||
to a `Normal`, this is interpreted as 1000 different locations, i.e.,
|
||||
1000 non-identical Normals. Therefore a single call to pdf(x) yields 1000
|
||||
probabilities, one for every location. The average over this batch yields
|
||||
the marginal.
|
||||
|
||||
~ Event dimensions:
|
||||
Computing the determinant of the Jacobian of a function of a random
|
||||
variable involves a reduction over event dimensions.
|
||||
|
||||
Examples:
|
||||
Write S, B, E for sample shape, batch shape, and event shape (resp.).
|
||||
|
||||
```python
|
||||
x.get_shape() == S + B + E # For statically known x shape.
|
||||
|
||||
# 100 iid samples from one multivariate Normal with two
|
||||
# degrees of freedom (DF).
|
||||
mu = [0., 0]
|
||||
sigma = [[1., 0],
|
||||
[0, 1]]
|
||||
X = MultivariateNormal(loc=mu, scale=sigma).sample_n(n=100)
|
||||
# S = [100]
|
||||
# B = []
|
||||
# E = [2]
|
||||
|
||||
# 100 iid samples from one Wishart with 2x2 DF.
|
||||
sigma = [[1., 0],
|
||||
[0, 1]]
|
||||
X = Wishart(scale=sigma).sample_n(n=100)
|
||||
# S = [100]
|
||||
# B = []
|
||||
# E = [2, 2]
|
||||
|
||||
# 100 iid samples (with shape [2, 50]) from two, non-identical bivariate
|
||||
# Normal distributions.
|
||||
mu = ... # shape(2, 2)
|
||||
sigma = ... # shape(2, 2, 2)
|
||||
X = MultivariateNormal(loc=mu, scale=sigma).sample(shape=[2, 50])
|
||||
# S = [2, 50]
|
||||
# B = [2]
|
||||
# E = [2]
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, batch_ndims=None, event_ndims=None, name='ShapeUtil'):
|
||||
"""Construct ShapeUtil with known sample, batch, and/or event ndims.
|
||||
|
||||
Typically, batch_ndims and event_ndims are fixed throughout the lifetime of
|
||||
a Distribution.
|
||||
|
||||
Args:
|
||||
batch_ndims: number of dims (rank) of the batch portion of indexes of a
|
||||
`Tensor`. A "batch" is a non-identical distribution, i.e, Normal with
|
||||
different parameters.
|
||||
event_ndims: number of dims (rank) of the event portion of indexes of a
|
||||
`Tensor`. An "event" is what is sampled from a distribution, i.e., a
|
||||
trivariate Normal has an event shape of [3] and a 4 dimensional Wishart
|
||||
has an event shape of [4, 4].
|
||||
name: `String`. The name to give Ops created by this class.
|
||||
|
||||
Raises:
|
||||
ValueError: if batch_ndims or event_ndims are invalid.
|
||||
"""
|
||||
if batch_ndims < 0:
|
||||
raise ValueError('must specify non-negative batch_ndims(%d)', batch_ndims)
|
||||
if batch_ndims > 0 and event_ndims < 1:
|
||||
raise ValueError('must specify positive event_ndims(%d) when '
|
||||
'batch_ndims(%d) is positive', event_ndims, batch_ndims)
|
||||
# TODO(jvdillon): Support batches of scalars.
|
||||
self._name = name
|
||||
self._batch_ndims = batch_ndims
|
||||
self._event_ndims = event_ndims
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Name given to ops created by this class."""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def batch_ndims(self):
|
||||
"""Returns number of dimensions corresponding to non-identical draws."""
|
||||
return self._batch_ndims
|
||||
|
||||
@property
|
||||
def event_ndims(self):
|
||||
"""Returns number of dimensions needed to index a sample's coordinates."""
|
||||
return self._event_ndims
|
||||
|
||||
def get_ndims(self, x, name='get_ndims'):
|
||||
"""Get tensor ndims (rank).
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
name: `String`. The name to give this op.
|
||||
|
||||
Raises:
|
||||
ValueError: if ndims is not statically known.
|
||||
|
||||
Returns:
|
||||
`Scalar` number of dimensions associated with a `Tensor`.
|
||||
"""
|
||||
if x is None:
|
||||
raise ValueError('Input was None which does not have known ndims.')
|
||||
with ops.name_scope(self.name):
|
||||
with ops.op_scope([x], name):
|
||||
ndims = ops.convert_to_tensor(x).get_shape().ndims
|
||||
if ndims is None:
|
||||
raise ValueError('ShapeUtil assumes static number of '
|
||||
'dimensions(%d)', ndims)
|
||||
return ndims
|
||||
|
||||
def get_sample_ndims(self, x):
|
||||
"""Returns number of dimensions corresponding to iid draws.
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
|
||||
Raises:
|
||||
ValueError: if batch_ndims or event_ndims are not statically known.
|
||||
ValueError: if static sample_ndims does not match inferred
|
||||
|
||||
Returns:
|
||||
Scalar number of dimensions associated with a sample.
|
||||
"""
|
||||
ndims = self.get_ndims(x)
|
||||
sample_ndims = ndims - self.batch_ndims - self.event_ndims
|
||||
if sample_ndims < 0:
|
||||
raise ValueError('expected batch_ndims(%d) + event_ndims(%d) < ndims(%d)',
|
||||
self.batch_ndims, self.event_ndims, ndims)
|
||||
return sample_ndims
|
||||
|
||||
def get_dims(self, x, sample=True, batch=True, event=True):
|
||||
"""Returns subset of tensor's dimension indexes (indexes into shape).
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
sample: `Boolean`. Include sample dimensions or not.
|
||||
batch: `Boolean`. Include batch dimensions or not.
|
||||
event: `Boolean`. Include event dimensions or not.
|
||||
|
||||
Raises:
|
||||
ValueError: if `x.get_shape().ndims` is `None`
|
||||
|
||||
Returns:
|
||||
List enumerating requested dimensions.
|
||||
"""
|
||||
ndims = self.get_ndims(x)
|
||||
|
||||
if sample and batch and event:
|
||||
return list(range(ndims))
|
||||
|
||||
sample_start = 0
|
||||
batch_start = self.get_sample_ndims(x)
|
||||
event_start = batch_start + self.batch_ndims
|
||||
|
||||
sample_shape = list(range(sample_start, batch_start)) if sample else []
|
||||
batch_shape = list(range(batch_start, event_start)) if batch else []
|
||||
event_shape = list(range(event_start, ndims)) if event else []
|
||||
|
||||
return sample_shape + batch_shape + event_shape
|
||||
|
||||
def get_shape(self, x, sample=True, batch=True, event=True, name='get_shape'):
|
||||
"""Returns subset of tensor's shape (size of dimensions).
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
sample: `Boolean`. Include sample shape or not.
|
||||
batch: `Boolean`. Include batch shape or not.
|
||||
event: `Boolean`. Include event shape or not.
|
||||
name: `String`. The name to give this op.
|
||||
|
||||
Raises:
|
||||
ValueError: if `x.get_shape().ndims` is `None`
|
||||
|
||||
Returns:
|
||||
List describing event shape if known statically, `Tensor` otherwise.
|
||||
"""
|
||||
if not sample and not batch and not event:
|
||||
return []
|
||||
with ops.name_scope(self._name):
|
||||
with ops.op_scope([x], name):
|
||||
x = ops.convert_to_tensor(x)
|
||||
shape = (x.get_shape().as_list()
|
||||
if x.get_shape().is_fully_defined()
|
||||
else array_ops.shape(x))
|
||||
|
||||
if sample and batch and event:
|
||||
return shape
|
||||
|
||||
sample_start = 0
|
||||
batch_start = self.get_sample_ndims(x)
|
||||
event_start = batch_start + self.batch_ndims
|
||||
|
||||
sample_shape = shape[sample_start:batch_start] if sample else []
|
||||
batch_shape = shape[batch_start:event_start] if batch else []
|
||||
event_shape = shape[event_start:] if event else []
|
||||
|
||||
if not batch and not event:
|
||||
return sample_shape
|
||||
if not sample and not event:
|
||||
return batch_shape
|
||||
if not sample and not batch:
|
||||
return event_shape
|
||||
|
||||
if x.get_shape().is_fully_defined():
|
||||
return sample_shape + batch_shape + event_shape
|
||||
else:
|
||||
return array_ops.concat(0, [sample_shape, batch_shape, event_shape])
|
||||
|
||||
def get_sample_dims(self, x):
|
||||
"""Returns dimension indexes corresponding to sample.
|
||||
|
||||
Convenience function; identical to:
|
||||
|
||||
```python
|
||||
get_dims(x, sample=True, batch=False, event=False)
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `x.get_shape().ndims` is `None`
|
||||
|
||||
Returns:
|
||||
List enumerating sample dimensions.
|
||||
"""
|
||||
return self.get_dims(x, sample=True, batch=False, event=False)
|
||||
|
||||
def get_batch_dims(self, x):
|
||||
"""Returns dimension indexes corresponding to batch.
|
||||
|
||||
Convenience function; identical to:
|
||||
|
||||
```python
|
||||
get_dims(x, sample=False, batch=True, event=False)
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `x.get_shape().ndims` is `None`
|
||||
|
||||
Returns:
|
||||
List enumerating batch dimensions.
|
||||
"""
|
||||
return self.get_dims(x, sample=False, batch=True, event=False)
|
||||
|
||||
def get_event_dims(self, x):
|
||||
"""Returns dimension indexes corresponding to event.
|
||||
|
||||
Convenience function; identical to:
|
||||
|
||||
```python
|
||||
get_dims(x, sample=False, batch=False, event=True)
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
|
||||
Raises:
|
||||
ValueError: if `x.get_shape().ndims` is `None`
|
||||
|
||||
Returns:
|
||||
List enumerating event dimensions.
|
||||
"""
|
||||
return self.get_dims(x, sample=False, batch=False, event=True)
|
||||
|
||||
def get_sample_shape(self, x):
|
||||
"""Returns shape corresponding to sample.
|
||||
|
||||
Convenience function; identical to:
|
||||
|
||||
```python
|
||||
get_shape(x, sample=True, batch=False, event=False)
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
|
||||
Returns:
|
||||
List describing sample shape if known statically, `Tensor` otherwise.
|
||||
"""
|
||||
return self.get_shape(x, sample=True, batch=False, event=False)
|
||||
|
||||
def get_batch_shape(self, x):
|
||||
"""Returns shape corresponding to batch.
|
||||
|
||||
Convenience function; identical to:
|
||||
|
||||
```python
|
||||
get_shape(x, sample=False, batch=True, event=False)
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
|
||||
Returns:
|
||||
List describing batch shape if known statically, `Tensor` otherwise.
|
||||
"""
|
||||
return self.get_shape(x, sample=False, batch=True, event=False)
|
||||
|
||||
def get_event_shape(self, x):
|
||||
"""Returns shape corresponding to event.
|
||||
|
||||
Convenience function; identical to:
|
||||
|
||||
```python
|
||||
get_shape(x, sample=False, batch=False, event=True)
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor`.
|
||||
|
||||
Returns:
|
||||
List describing event shape if known statically, `Tensor` otherwise.
|
||||
"""
|
||||
return self.get_shape(x, sample=False, batch=False, event=True)
|
@ -41,6 +41,7 @@ py_test(
|
||||
size = "small",
|
||||
srcs = ["python/framework/checkpoint_utils_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["manual"], # http://b/30468735
|
||||
deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
|
@ -1036,6 +1036,7 @@ def separable_convolution2d(
|
||||
depthwise_weights = variables.model_variable(
|
||||
'depthwise_weights',
|
||||
shape=depthwise_shape,
|
||||
dtype=dtype,
|
||||
initializer=weights_initializer,
|
||||
regularizer=weights_regularizer,
|
||||
trainable=trainable,
|
||||
@ -1048,6 +1049,7 @@ def separable_convolution2d(
|
||||
pointwise_weights = variables.model_variable(
|
||||
'pointwise_weights',
|
||||
shape=pointwise_shape,
|
||||
dtype=dtype,
|
||||
initializer=weights_initializer,
|
||||
regularizer=weights_regularizer,
|
||||
trainable=trainable,
|
||||
|
@ -1604,10 +1604,28 @@ class RepeatTests(tf.test.TestCase):
|
||||
|
||||
class SeparableConv2dTest(tf.test.TestCase):
|
||||
|
||||
def testCreateConv(self):
|
||||
def testCreateConvInt32(self):
|
||||
height, width = 3, 3
|
||||
with self.test_session():
|
||||
images = tf.random_uniform((5, height, width, 3), seed=1)
|
||||
images = tf.random_uniform(
|
||||
(5, height, width, 3), seed=1, dtype=tf.int32, maxval=12345)
|
||||
with self.assertRaisesRegexp(TypeError, 'non-floating point type'):
|
||||
tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
|
||||
|
||||
def testCreateConvFloat32(self):
|
||||
height, width = 3, 3
|
||||
with self.test_session():
|
||||
images = tf.random_uniform(
|
||||
(5, height, width, 3), seed=1, dtype=tf.float32)
|
||||
output = tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
|
||||
self.assertEquals(output.op.name, 'SeparableConv2d/Relu')
|
||||
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
|
||||
|
||||
def testCreateConvFloat64(self):
|
||||
height, width = 3, 3
|
||||
with self.test_session():
|
||||
images = tf.random_uniform(
|
||||
(5, height, width, 3), seed=1, dtype=tf.float64)
|
||||
output = tf.contrib.layers.separable_conv2d(images, 32, [3, 3], 2)
|
||||
self.assertEquals(output.op.name, 'SeparableConv2d/Relu')
|
||||
self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
|
||||
|
@ -123,8 +123,8 @@ class Scaffold(object):
|
||||
global_step_tensor = contrib_variables.get_or_create_global_step()
|
||||
self.global_step_tensor = global_step_tensor
|
||||
if init_op is None:
|
||||
init_op = Scaffold._get_or_default(
|
||||
ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
|
||||
init_op = Scaffold._get_or_default('init_op', ops.GraphKeys.INIT_OP,
|
||||
variables.initialize_all_variables)
|
||||
self.init_op = init_op
|
||||
self.init_feed_dict = init_feed_dict
|
||||
# NOTE(touts): modifying the init function to be passed the scaffold is a
|
||||
@ -135,19 +135,23 @@ class Scaffold(object):
|
||||
self.init_fn = None
|
||||
if ready_op is None:
|
||||
ready_op = Scaffold._get_or_default(
|
||||
ops.GraphKeys.READY_OP, variables.report_uninitialized_variables)
|
||||
'ready_op', ops.GraphKeys.READY_OP,
|
||||
variables.report_uninitialized_variables)
|
||||
self.ready_op = ready_op
|
||||
if local_init_op is None:
|
||||
local_init_op = Scaffold._get_or_default(
|
||||
ops.GraphKeys.LOCAL_INIT_OP, Scaffold._default_local_init_op)
|
||||
local_init_op = Scaffold._get_or_default('local_init_op',
|
||||
ops.GraphKeys.LOCAL_INIT_OP,
|
||||
Scaffold._default_local_init_op)
|
||||
self.local_init_op = local_init_op
|
||||
if summary_op is None:
|
||||
summary_op = Scaffold._get_or_default(
|
||||
ops.GraphKeys.SUMMARY_OP, logging_ops.merge_all_summaries)
|
||||
summary_op = Scaffold._get_or_default('summary_op',
|
||||
ops.GraphKeys.SUMMARY_OP,
|
||||
logging_ops.merge_all_summaries)
|
||||
self.summary_op = summary_op
|
||||
# pylint: disable=g-long-lambda
|
||||
if saver is None:
|
||||
saver = Scaffold._get_or_default(
|
||||
'saver',
|
||||
ops.GraphKeys.SAVERS,
|
||||
lambda: training_saver.Saver(sharded=True,
|
||||
max_to_keep=keep_checkpoint_max))
|
||||
@ -157,9 +161,16 @@ class Scaffold(object):
|
||||
ops.get_default_graph().finalize()
|
||||
|
||||
@staticmethod
|
||||
def _get_or_default(collection_key, default_constructor):
|
||||
def _get_or_default(arg_name, collection_key, default_constructor):
|
||||
"""Get from cache or create a default operation."""
|
||||
elements = ops.get_collection(collection_key)
|
||||
if elements:
|
||||
if len(elements) > 1:
|
||||
raise RuntimeError('More than one item in the collection "%s". '
|
||||
'Please indicate which one to use by passing it to '
|
||||
'the tf.Scaffold constructor as: '
|
||||
'tf.Scaffold(%s=item to use)', collection_key,
|
||||
arg_name)
|
||||
return elements[0]
|
||||
op = default_constructor()
|
||||
if op is not None:
|
||||
|
@ -57,6 +57,14 @@ class ScaffoldTest(tf.test.TestCase):
|
||||
self.assertEqual(scaffold1.local_init_op, scaffold2.local_init_op)
|
||||
self.assertEqual(scaffold1.saver, scaffold2.saver)
|
||||
|
||||
def test_raise_error_if_more_than_one_cached_item(self):
|
||||
with tf.Graph().as_default():
|
||||
tf.Variable([1])
|
||||
tf.add_to_collection(tf.GraphKeys.SAVERS, tf.train.Saver())
|
||||
tf.add_to_collection(tf.GraphKeys.SAVERS, tf.train.Saver())
|
||||
with self.assertRaisesRegexp(RuntimeError, 'More than one item'):
|
||||
supervised_session.Scaffold()
|
||||
|
||||
def test_uses_passed_values(self):
|
||||
with tf.Graph().as_default():
|
||||
scaffold = supervised_session.Scaffold(global_step_tensor=1,
|
||||
|
@ -393,10 +393,8 @@ $(wildcard tensorflow/core/graph/dot.*) \
|
||||
$(wildcard tensorflow/core/lib/gif/*) \
|
||||
$(wildcard tensorflow/core/lib/jpeg/*) \
|
||||
$(wildcard tensorflow/core/lib/png/*) \
|
||||
$(wildcard tensorflow/core/util/checkpoint_reader.*) \
|
||||
$(wildcard tensorflow/core/util/events_writer.*) \
|
||||
$(wildcard tensorflow/core/util/reporter.*) \
|
||||
$(wildcard tensorflow/core/util/tf_status_helper.*) \
|
||||
$(wildcard tensorflow/core/platform/default/stream_executor.*) \
|
||||
$(wildcard tensorflow/core/platform/default/test_benchmark.*) \
|
||||
$(wildcard tensorflow/core/platform/cuda.h) \
|
||||
|
@ -13,11 +13,14 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::Shape;
|
||||
|
||||
REGISTER_OP("QuantizeV2")
|
||||
.Input("input: float")
|
||||
@ -28,6 +31,15 @@ REGISTER_OP("QuantizeV2")
|
||||
.Output("output_max: float")
|
||||
.Attr("T: quantizedtype")
|
||||
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
|
||||
const Shape* unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
c->set_output(1, c->Scalar());
|
||||
c->set_output(2, c->Scalar());
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Quantize the 'input' tensor of type float to 'output' tensor of type 'T'.
|
||||
|
||||
@ -96,6 +108,13 @@ REGISTER_OP("Dequantize")
|
||||
.Output("output: float")
|
||||
.Attr("T: quantizedtype")
|
||||
.Attr("mode: {'MIN_COMBINED', 'MIN_FIRST'} = 'MIN_COMBINED'")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
|
||||
const Shape* unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Dequantize the 'input' tensor into a float Tensor.
|
||||
|
||||
|
@ -156,6 +156,15 @@ REGISTER_OP("QuantizedMaxPool")
|
||||
.Attr("ksize: list(int)")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
|
||||
const Shape* unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
|
||||
c->set_output(1, c->Scalar());
|
||||
c->set_output(2, c->Scalar());
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Produces the max pool of the input tensor for quantized types.
|
||||
|
||||
|
@ -198,7 +198,7 @@ of the corresponding TF-Slim code:
|
||||
|
||||
```python
|
||||
input = ...
|
||||
net = slim.conv2d(input, [3, 3], 128, scope='conv1_1')
|
||||
net = slim.conv2d(input, 128, [3, 3], scope='conv1_1')
|
||||
```
|
||||
|
||||
TF-Slim provides standard implementations for numerous components for building
|
||||
@ -431,7 +431,7 @@ between the predicted and true values.
|
||||
|
||||
Certain models, such as multi-task
|
||||
learning models, require the use of multiple loss functions simultaneously. In
|
||||
other words, the loss function ultimatey being minimized is the sum of various
|
||||
other words, the loss function ultimately being minimized is the sum of various
|
||||
other loss functions. For example, consider a model that predicts both
|
||||
the type of scene in an image as well as the depth from the
|
||||
camera of each pixel. This model's loss function would be the sum of the
|
||||
|
@ -156,6 +156,7 @@ cc_library(
|
||||
"lib/io/table_options.h",
|
||||
"lib/jpeg/jpeg_mem.h",
|
||||
"lib/monitoring/counter.h",
|
||||
"lib/monitoring/metric_def.h",
|
||||
"lib/random/distribution_sampler.h",
|
||||
"lib/random/philox_random.h",
|
||||
"lib/random/simple_philox.h", # TODO(josh11b): make internal
|
||||
@ -423,9 +424,6 @@ tf_cuda_library(
|
||||
"graph/validate.h",
|
||||
"public/session.h",
|
||||
"public/session_options.h",
|
||||
"public/tensor_c_api.h",
|
||||
"util/checkpoint_reader.h",
|
||||
"util/tf_status_helper.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
@ -605,10 +603,8 @@ filegroup(
|
||||
"lib/jpeg/**/*",
|
||||
"lib/png/**/*",
|
||||
"lib/gif/**/*",
|
||||
"util/checkpoint_reader.*",
|
||||
"util/events_writer.*",
|
||||
"util/reporter.*",
|
||||
"util/tf_status_helper.*",
|
||||
"platform/default/stream_executor.*",
|
||||
"platform/default/test_benchmark.*",
|
||||
"platform/cuda.h",
|
||||
@ -917,8 +913,6 @@ tf_cuda_library(
|
||||
"**/*test*",
|
||||
"**/*main.cc",
|
||||
"example/example_parser_configuration.*",
|
||||
"util/tf_status_helper.*",
|
||||
"util/checkpoint_reader.*",
|
||||
"util/reporter.h",
|
||||
"util/reporter.cc",
|
||||
"framework/fake_input.*",
|
||||
@ -1063,10 +1057,7 @@ tf_cuda_library(
|
||||
"graph/**/*.cc",
|
||||
"public/session.h",
|
||||
"public/session_options.h",
|
||||
"public/tensor_c_api.h",
|
||||
"public/version.h",
|
||||
"util/tf_status_helper.*",
|
||||
"util/checkpoint_reader.*",
|
||||
],
|
||||
exclude = [
|
||||
"**/*test*",
|
||||
@ -1497,36 +1488,6 @@ tf_cc_tests(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_tests(
|
||||
size = "small",
|
||||
linkopts = select({
|
||||
"//tensorflow:darwin": ["-headerpad_max_install_names"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tests = [
|
||||
"client/tensor_c_api_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":core",
|
||||
":core_cpu",
|
||||
":core_cpu_internal",
|
||||
":direct_session_internal",
|
||||
":framework",
|
||||
":framework_internal",
|
||||
":lib",
|
||||
":lib_internal",
|
||||
":proto_text",
|
||||
":protos_all_cc",
|
||||
":test",
|
||||
":test_main",
|
||||
":testlib",
|
||||
"//tensorflow/core/kernels:array",
|
||||
"//tensorflow/core/kernels:math",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
)
|
||||
|
||||
# GPU-related tests
|
||||
tf_cc_tests_gpu(
|
||||
size = "small",
|
||||
@ -2000,6 +1961,7 @@ filegroup(
|
||||
"example/testdata/parse_example_graph_def.pbtxt",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Google-internal targets go here (must be at the end).
|
||||
|
||||
|
@ -182,8 +182,8 @@ Status GrpcServer::Init() {
|
||||
|
||||
// Finish setting up worker environment.
|
||||
worker_env_.graph_mgr = new GraphMgr(&worker_env_);
|
||||
worker_env_.rendezvous_mgr = new RpcRendezvousMgr(&worker_env_);
|
||||
worker_env_.compute_pool = ComputePool(sess_opts);
|
||||
worker_env_.rendezvous_mgr = new RpcRendezvousMgr(&worker_env_);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -37,8 +37,9 @@ namespace {
|
||||
|
||||
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
public:
|
||||
RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
|
||||
: BaseRemoteRendezvous(env, step_id, false) {}
|
||||
RpcRemoteRendezvous(const WorkerEnv* env, WorkerCacheInterface* cache,
|
||||
int64 step_id)
|
||||
: BaseRemoteRendezvous(env, step_id, false), cache_(cache) {}
|
||||
|
||||
protected:
|
||||
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
@ -48,6 +49,7 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
private:
|
||||
~RpcRemoteRendezvous() override {}
|
||||
|
||||
WorkerCacheInterface* cache_; // Not owned.
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
|
||||
};
|
||||
|
||||
@ -55,13 +57,12 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
public:
|
||||
RpcRecvTensorCall()
|
||||
: wi_(nullptr), wc_(nullptr), allocator_(nullptr), dst_device_(nullptr) {}
|
||||
: wi_(nullptr), allocator_(nullptr), dst_device_(nullptr) {}
|
||||
|
||||
void Init(WorkerCacheInterface* wc, WorkerInterface* wi, int64 step_id,
|
||||
StringPiece key, Allocator* allocator, Device* dst_device,
|
||||
void Init(WorkerInterface* wi, int64 step_id, StringPiece key,
|
||||
Allocator* allocator, Device* dst_device,
|
||||
const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
|
||||
wi_ = wi;
|
||||
wc_ = wc;
|
||||
allocator_ = allocator;
|
||||
dst_device_ = dst_device;
|
||||
recv_args_ = recv_args;
|
||||
@ -73,7 +74,6 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
void Reset() {
|
||||
delete wi_;
|
||||
wi_ = nullptr;
|
||||
wc_ = nullptr;
|
||||
allocator_ = nullptr;
|
||||
dst_device_ = nullptr;
|
||||
// We don't clear opts_ and assume that Init will set up the state for
|
||||
@ -123,6 +123,8 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
const Rendezvous::DoneCallback& done() const { return done_; }
|
||||
|
||||
private:
|
||||
friend class RpcRemoteRendezvous;
|
||||
|
||||
// Start the main RecvTensor call, checking for an async abort.
|
||||
void StartRTCall(std::function<void()> recv_done) {
|
||||
wi_->RecvTensorAsync(&opts_, &req_, &resp_,
|
||||
@ -137,8 +139,9 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
});
|
||||
}
|
||||
|
||||
WorkerInterface* wi_; // Owned.
|
||||
WorkerCacheInterface* wc_; // Not owned.
|
||||
string src_worker_;
|
||||
string src_rel_device_;
|
||||
WorkerInterface* wi_;
|
||||
Allocator* allocator_;
|
||||
Device* dst_device_;
|
||||
CallOptions opts_;
|
||||
@ -153,7 +156,6 @@ class RpcRecvTensorCall : public BaseRecvTensorCall {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
|
||||
};
|
||||
|
||||
namespace {
|
||||
class RpcRecvTensorFreeList {
|
||||
public:
|
||||
RpcRecvTensorFreeList() {}
|
||||
@ -195,32 +197,99 @@ class RpcRecvTensorFreeList {
|
||||
};
|
||||
|
||||
static RpcRecvTensorFreeList call_freelist_;
|
||||
}
|
||||
|
||||
// A private cache that wraps env->worker_cache and allows reuse of
|
||||
// WorkerInterface objects.
|
||||
class WorkerFreeListCache : public WorkerCacheInterface {
|
||||
public:
|
||||
explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {}
|
||||
|
||||
~WorkerFreeListCache() {
|
||||
for (auto p : workers_) {
|
||||
delete p.second.worker;
|
||||
}
|
||||
}
|
||||
|
||||
void ListWorkers(std::vector<string>* workers) override {
|
||||
wrapped_->ListWorkers(workers);
|
||||
}
|
||||
|
||||
WorkerInterface* CreateWorker(const string& target) override {
|
||||
mutex_lock l(mu_);
|
||||
auto p = workers_.find(target);
|
||||
if (p != workers_.end()) {
|
||||
return p->second.worker;
|
||||
}
|
||||
WorkerState state;
|
||||
state.worker = wrapped_->CreateWorker(target);
|
||||
if (state.worker != nullptr) {
|
||||
workers_.insert(make_pair(target, state));
|
||||
}
|
||||
return state.worker;
|
||||
}
|
||||
|
||||
void ReleaseWorker(const string& target, WorkerInterface* worker) override {
|
||||
// TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
|
||||
}
|
||||
|
||||
bool GetDeviceBusNonBlocking(const string& device,
|
||||
BusAdjacency* ba) override {
|
||||
return wrapped_->GetDeviceBusNonBlocking(device, ba);
|
||||
}
|
||||
|
||||
void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
|
||||
StatusCallback done) override {
|
||||
wrapped_->GetDeviceBusAsync(device, ba, done);
|
||||
}
|
||||
|
||||
void SetLogging(bool active) override { wrapped_->SetLogging(active); }
|
||||
|
||||
void ClearLogs() override { wrapped_->ClearLogs(); }
|
||||
|
||||
bool RetrieveLogs(int64 step_id, StepStats* ss) override {
|
||||
return wrapped_->RetrieveLogs(step_id, ss);
|
||||
}
|
||||
|
||||
private:
|
||||
WorkerCacheInterface* wrapped_;
|
||||
|
||||
// Information kept per created WorkerInterface.
|
||||
struct WorkerState {
|
||||
WorkerInterface* worker;
|
||||
// TODO(jeff,sanjay): Add reference count if we support eviction.
|
||||
};
|
||||
|
||||
// TODO(jeff,sanjay): Eviction when the map becomes too big.
|
||||
mutex mu_;
|
||||
std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
Status s;
|
||||
|
||||
// key.src_device identifies a remote device.
|
||||
string src_worker;
|
||||
string src_rel_device;
|
||||
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
|
||||
&src_rel_device)) {
|
||||
s = errors::Internal(parsed.src_device,
|
||||
" is invalid remote source device.");
|
||||
}
|
||||
// TODO(jeff): Consider checking for a valid worker_cache during the
|
||||
// constructor of RpcRemoteRendezvous, rather than here, to simplify
|
||||
// the twisty logic below.
|
||||
WorkerCacheInterface* worker_cache = env_->worker_cache;
|
||||
if (s.ok() && worker_cache == nullptr) {
|
||||
if (env_->worker_cache == nullptr) {
|
||||
s = errors::Internal("No remote worker cache available.");
|
||||
done(s, Args(), recv_args, Tensor{}, false);
|
||||
return;
|
||||
}
|
||||
WorkerInterface* rwi =
|
||||
(worker_cache ? worker_cache->CreateWorker(src_worker) : nullptr);
|
||||
|
||||
// Prepare a RecvTensor call that can handle being aborted.
|
||||
RpcRecvTensorCall* call = call_freelist_.New();
|
||||
|
||||
// key.src_device identifies a remote device.
|
||||
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
|
||||
&call->src_rel_device_)) {
|
||||
s = errors::Internal(parsed.src_device,
|
||||
" is invalid remote source device.");
|
||||
}
|
||||
WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_);
|
||||
if (s.ok() && rwi == nullptr) {
|
||||
s = errors::Internal("No worker known as ", src_worker);
|
||||
s = errors::Internal("No worker known as ", call->src_worker_);
|
||||
}
|
||||
|
||||
Device* dst_device;
|
||||
@ -228,21 +297,20 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
call_freelist_.Release(call);
|
||||
done(s, Args(), recv_args, Tensor{}, false);
|
||||
return;
|
||||
}
|
||||
Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs);
|
||||
|
||||
// Prepare a RecvTensor call that can handle being aborted.
|
||||
RpcRecvTensorCall* call = call_freelist_.New();
|
||||
|
||||
call->Init(worker_cache, rwi, step_id_, parsed.FullKey(), allocator,
|
||||
dst_device, recv_args, std::move(done));
|
||||
call->Init(rwi, step_id_, parsed.FullKey(), allocator, dst_device, recv_args,
|
||||
std::move(done));
|
||||
|
||||
// Record "call" in active_ so that it can be aborted cleanly.
|
||||
RegisterCall(call);
|
||||
|
||||
// Start "call".
|
||||
Ref();
|
||||
call->Start([this, call]() {
|
||||
// Removes "call" from active_. Prevent StartAbort().
|
||||
DeregisterCall(call);
|
||||
@ -255,15 +323,22 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
|
||||
call->tensor_proto(), call->recv_args().alloc_attrs, &val);
|
||||
}
|
||||
call->done()(s, Args(), call->recv_args(), val, call->is_dead());
|
||||
cache_->ReleaseWorker(call->src_worker_, call->wi_);
|
||||
call->wi_ = nullptr;
|
||||
call_freelist_.Release(call);
|
||||
Unref();
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
|
||||
: BaseRendezvousMgr(env),
|
||||
cache_(new WorkerFreeListCache(env->worker_cache)) {}
|
||||
|
||||
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) {
|
||||
return new RpcRemoteRendezvous(worker_env, step_id);
|
||||
return new RpcRemoteRendezvous(worker_env, cache_.get(), step_id);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
@ -42,13 +43,16 @@ namespace tensorflow {
|
||||
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
|
||||
class RpcRendezvousMgr : public BaseRendezvousMgr {
|
||||
public:
|
||||
explicit RpcRendezvousMgr(const WorkerEnv* env) : BaseRendezvousMgr(env) {}
|
||||
explicit RpcRendezvousMgr(const WorkerEnv* env);
|
||||
|
||||
protected:
|
||||
BaseRemoteRendezvous* Create(int64 step_id,
|
||||
const WorkerEnv* worker_env) override;
|
||||
|
||||
private:
|
||||
// Private cache_ that allows us to reuse WorkerInterface objects.
|
||||
std::unique_ptr<WorkerCacheInterface> cache_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
|
||||
};
|
||||
|
||||
|
@ -46,10 +46,28 @@ Rendezvous::ParsedKey MakeKey(const string& s) {
|
||||
return key;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Fake cache implementation for WorkerEnv.
|
||||
class DummyWorkerCache : public WorkerCacheInterface {
|
||||
void ListWorkers(std::vector<string>* workers) override {}
|
||||
WorkerInterface* CreateWorker(const string& target) override {
|
||||
return nullptr;
|
||||
}
|
||||
bool GetDeviceBusNonBlocking(const string& device,
|
||||
BusAdjacency* ba) override {
|
||||
return false;
|
||||
}
|
||||
void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
|
||||
StatusCallback done) override {}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
TEST(RpcRendezvousMgrTest, LocalSendRecv) {
|
||||
DummyWorkerCache cache;
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
env.worker_cache = &cache;
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const int64 step_id = 123;
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
@ -71,9 +89,11 @@ TEST(RpcRendezvousMgrTest, LocalSendRecv) {
|
||||
}
|
||||
|
||||
TEST(RpcRendezvousMgrTest, LocalAbort) {
|
||||
DummyWorkerCache cache;
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
env.worker_cache = &cache;
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
@ -107,9 +127,11 @@ TEST(RpcRendezvousMgrTest, LocalAbort) {
|
||||
}
|
||||
|
||||
TEST(RpcRendezvousMgrTest, CleanupAll) {
|
||||
DummyWorkerCache cache;
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
env.worker_cache = &cache;
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
"/job:mnist/replica:1/task:2/cpu:0", 7890,
|
||||
@ -140,9 +162,11 @@ class DummyDeviceContext : public DeviceContext {
|
||||
TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
|
||||
DummyDeviceContext* dc = new DummyDeviceContext(123);
|
||||
|
||||
DummyWorkerCache cache;
|
||||
WorkerEnv env;
|
||||
env.env = Env::Default();
|
||||
env.worker_name = "/job:mnist/replica:1/task:2";
|
||||
env.worker_cache = &cache;
|
||||
RpcRendezvousMgr rmgr(&env);
|
||||
const int64 step_id = 123;
|
||||
const Rendezvous::ParsedKey key = MakeKey(Rendezvous::CreateKey(
|
||||
|
@ -19,7 +19,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h" // for CallOptions
|
||||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/device_attributes.pb.h" // for BusAdjacency
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
@ -28,7 +28,6 @@ typedef std::function<void(const Status&)> StatusCallback;
|
||||
|
||||
class ChannelCache;
|
||||
class StepStats;
|
||||
class WorkerInterface;
|
||||
|
||||
class WorkerCacheInterface {
|
||||
public:
|
||||
@ -46,6 +45,17 @@ class WorkerCacheInterface {
|
||||
// ownership, not a cache lookup.
|
||||
virtual WorkerInterface* CreateWorker(const string& target) = 0;
|
||||
|
||||
// Release a worker previously returned by this->CreateWorker(target).
|
||||
//
|
||||
// TODO(jeff,sanjay): Consider moving target into WorkerInterface.
|
||||
// TODO(jeff,sanjay): Consider disallowing direct deletion of WorkerInterface.
|
||||
// TODO(jeff,sanjay): Unify all worker-cache impls and factor out a
|
||||
// per-rpc-subsystem WorkerInterface creator.
|
||||
virtual void ReleaseWorker(const string& target, WorkerInterface* worker) {
|
||||
// Subclasses may override to reuse worker objects.
|
||||
delete worker;
|
||||
}
|
||||
|
||||
// Set *ba with the BusAdjacency of the specified remote device
|
||||
// within its local environment. Returns true if the device bus
|
||||
// affinity was set, using only locally cached data. Returns false
|
||||
|
@ -414,6 +414,99 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Shape function for MaxPool-like operations.
//
// Requires input 0 to have rank 4 and reads the "strides", "ksize" and
// "padding" attrs; "strides" and "ksize" must each hold exactly 4 values.
// The "data_format" attr is optional: when it is present and equal to "NCHW"
// the input is permuted to NHWC for the computation and the result is permuted
// back; in every other case NHWC is assumed.
//
// The batch dimension is passed through unchanged. The rows, cols and depth
// dimensions must be known and are each reduced with
// GetWindowedOutputSizeVerbose using the matching kernel size and stride.
//
// Returns an error status if any attr is missing or malformed, if a required
// dimension is unknown, or if the windowed output size cannot be computed.
Status MaxPoolShape(shape_inference::InferenceContext* c) {
  const Shape* input_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));

  // A failed lookup of "data_format" is not an error; it selects NHWC.
  string data_format;
  Status s = c->GetAttr("data_format", &data_format);
  const bool is_nchw = s.ok() && data_format == "NCHW";

  std::vector<int32> strides;
  TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
  if (strides.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the stride attribute to contain 4 values, but "
        "got: ",
        strides.size());
  }

  std::vector<int32> kernel_sizes;
  TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
  if (kernel_sizes.size() != 4) {
    return errors::InvalidArgument(
        "MaxPool requires the ksize attribute to contain 4 values, but got: ",
        kernel_sizes.size());
  }

  int32 stride_rows, stride_cols, stride_depth;
  int32 kernel_rows, kernel_cols, kernel_depth;

  if (is_nchw) {
    // Convert input shape to default NHWC for inference
    input_shape =
        c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2),
                       c->Dim(input_shape, 3), c->Dim(input_shape, 1)}});
    stride_depth = strides[1];
    stride_rows = strides[2];
    stride_cols = strides[3];
    kernel_depth = kernel_sizes[1];
    kernel_rows = kernel_sizes[2];
    kernel_cols = kernel_sizes[3];
  } else {
    stride_rows = strides[1];
    stride_cols = strides[2];
    stride_depth = strides[3];
    kernel_rows = kernel_sizes[1];
    kernel_cols = kernel_sizes[2];
    kernel_depth = kernel_sizes[3];
  }

  const Dimension* batch_size_dim = c->Dim(input_shape, 0);
  const Dimension* in_rows_dim = c->Dim(input_shape, 1);
  const Dimension* in_cols_dim = c->Dim(input_shape, 2);
  const Dimension* in_depth_dim = c->Dim(input_shape, 3);

  // At the moment we need to know the values of several fields.
  TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
  TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
  TF_RETURN_IF_ERROR(CheckKnownDim(c, in_depth_dim, "in_depth"));

  Padding padding;
  TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));

  // TODO(mrry,shlens): Raise an error if the stride would cause
  // information in the input to be ignored. This will require a change
  // in the kernel implementation.
  auto in_rows = c->Value(in_rows_dim);
  auto in_cols = c->Value(in_cols_dim);
  auto in_depth = c->Value(in_depth_dim);

  // Only the output sizes are used below; the padding amounts are required
  // out-parameters of GetWindowedOutputSizeVerbose and are overwritten by
  // each call.
  int64 output_rows, output_cols, output_depth;
  int64 padding_before, padding_after;
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
      in_rows, kernel_rows, stride_rows, padding, &output_rows, &padding_before,
      &padding_after));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
      in_cols, kernel_cols, stride_cols, padding, &output_cols, &padding_before,
      &padding_after));
  TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
      in_depth, kernel_depth, stride_depth, padding, &output_depth,
      &padding_before, &padding_after));

  const Shape* output_shape =
      c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth});

  if (is_nchw) {
    // Convert output shape back to expected NCHW data format.
    output_shape =
        c->MakeShape({c->Dim(output_shape, 0), c->Dim(output_shape, 3),
                      c->Dim(output_shape, 1), c->Dim(output_shape, 2)});
  }

  c->set_output(0, output_shape);
  return Status::OK();
}
|
||||
|
||||
Status UnknownShape(shape_inference::InferenceContext* c) {
|
||||
for (int i = 0; i < c->num_outputs(); ++i) {
|
||||
c->set_output(i, c->UnknownShape());
|
||||
|
@ -163,6 +163,9 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
|
||||
// Shape function for AvgPool-like operations.
|
||||
Status AvgPoolShape(shape_inference::InferenceContext* c);
|
||||
|
||||
// Shape function for MaxPool-like operations.
|
||||
Status MaxPoolShape(shape_inference::InferenceContext* c);
|
||||
|
||||
// Shape function for use with ops whose output shapes are unknown.
|
||||
Status UnknownShape(shape_inference::InferenceContext* c);
|
||||
|
||||
|
@ -485,6 +485,33 @@ TEST(CommonShapeFnsTest, AvgPool2DShapeTest) {
|
||||
INFER_ERROR("must be rank 4", op, "[4,4]");
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, MaxPool2DShapeTest) {
|
||||
ShapeInferenceTestOp op("MaxPool");
|
||||
auto set_op = [&op](const std::vector<int32>& strides,
|
||||
const std::vector<int32>& ksizes, const string& padding,
|
||||
const string& data_format) {
|
||||
TF_CHECK_OK(NodeDefBuilder("test", "MaxPool")
|
||||
.Input("input", 0, DT_FLOAT)
|
||||
.Attr("strides", strides)
|
||||
.Attr("ksize", ksizes)
|
||||
.Attr("padding", padding)
|
||||
.Attr("data_format", data_format)
|
||||
.Finalize(&op.node_def));
|
||||
};
|
||||
|
||||
// Most of the functionality is tested by conv-like shapes,
|
||||
// so we check the very-specific maxpooling features here,
|
||||
// namely depthwise kernel and striding.
|
||||
|
||||
// all 1 strides, depth 2 filter
|
||||
set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC");
|
||||
INFER_OK(op, "[1,2,2,2]", "[d0_0,2,2,1]");
|
||||
|
||||
// depth 3 stride, 1x1x1 filter, NCHW
|
||||
set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW");
|
||||
INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]");
|
||||
}
|
||||
|
||||
TEST(CommonShapeFnsTest, UnknownShapeTest) {
|
||||
{
|
||||
// Single output
|
||||
|
@ -492,6 +492,10 @@ Status InferenceContext::MakeDimForScalarInput(int idx, const Dimension** out) {
|
||||
*out = UnknownDim();
|
||||
return Status::OK();
|
||||
}
|
||||
const int rank = t->dims();
|
||||
if (rank != 0) {
|
||||
return errors::InvalidArgument("Input must be scalar but has rank ", rank);
|
||||
}
|
||||
|
||||
int64 val;
|
||||
if (t->dtype() == DT_INT32) {
|
||||
|
@ -58,6 +58,8 @@ class BCastGradArgsOp : public OpKernel {
|
||||
Output(ctx, 1, bcast.grad_y_reduce_idx());
|
||||
}
|
||||
|
||||
bool IsExpensive() override { return false; }
|
||||
|
||||
private:
|
||||
void Output(OpKernelContext* ctx, int idx, const BCast::Vec& v) {
|
||||
const int64 len = v.size();
|
||||
|
@ -85,26 +85,28 @@ class OneHotOp : public OpKernel {
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output));
|
||||
|
||||
// prefix_dim_size == # of elements before the axis
|
||||
// depth_v == # of elements per axis
|
||||
// suffix_dim_size == # of elements after the axis
|
||||
int64 prefix_dim_size = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
prefix_dim_size *= indices_shape.dim_size(i);
|
||||
if (output_shape.num_elements() > 0) {
|
||||
// prefix_dim_size == # of elements before the axis
|
||||
// depth_v == # of elements per axis
|
||||
// suffix_dim_size == # of elements after the axis
|
||||
int64 prefix_dim_size = 1;
|
||||
for (int i = 0; i < axis; ++i) {
|
||||
prefix_dim_size *= indices_shape.dim_size(i);
|
||||
}
|
||||
TI suffix_dim_size = indices_shape.num_elements() / prefix_dim_size;
|
||||
|
||||
// Split indices into matrix of size prefix_dim_size x suffix_dim_size
|
||||
auto indices_t =
|
||||
indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size});
|
||||
// Split output into 3-Tensor of size:
|
||||
// prefix_dim_size x depth x suffix_dim_size.
|
||||
auto output_t =
|
||||
output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});
|
||||
|
||||
functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(),
|
||||
indices_t, on_value_t,
|
||||
off_value_t, &output_t);
|
||||
}
|
||||
TI suffix_dim_size =
|
||||
indices_shape.num_elements() / prefix_dim_size;
|
||||
|
||||
// Split indices into matrix of size prefix_dim_size x suffix_dim_size
|
||||
auto indices_t =
|
||||
indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size});
|
||||
// Split output into 3-Tensor of size:
|
||||
// prefix_dim_size x depth x suffix_dim_size.
|
||||
auto output_t =
|
||||
output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});
|
||||
|
||||
functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(), indices_t,
|
||||
on_value_t, off_value_t, &output_t);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -113,12 +115,12 @@ class OneHotOp : public OpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
|
||||
};
|
||||
|
||||
#define REGISTER_ONE_HOT_INDEX(type, index_type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("OneHot") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<index_type>("TI") \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("depth"), \
|
||||
#define REGISTER_ONE_HOT_INDEX(type, index_type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("OneHot") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<index_type>("TI") \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("depth"), \
|
||||
OneHotOp<CPUDevice, type, index_type>);
|
||||
|
||||
#define REGISTER_ONE_HOT(type) \
|
||||
@ -132,13 +134,13 @@ TF_CALL_ALL_TYPES(REGISTER_ONE_HOT);
|
||||
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC_INDEX(T, TI) \
|
||||
template <> \
|
||||
void OneHot<GPUDevice, T, TI>::Compute( \
|
||||
const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices, \
|
||||
const typename TTypes<T>::ConstScalar& on_value, \
|
||||
const typename TTypes<T>::ConstScalar& off_value, \
|
||||
typename TTypes<T, 3>::Tensor* output); \
|
||||
#define DECLARE_GPU_SPEC_INDEX(T, TI) \
|
||||
template <> \
|
||||
void OneHot<GPUDevice, T, TI>::Compute( \
|
||||
const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices, \
|
||||
const typename TTypes<T>::ConstScalar& on_value, \
|
||||
const typename TTypes<T>::ConstScalar& off_value, \
|
||||
typename TTypes<T, 3>::Tensor* output); \
|
||||
extern template struct OneHot<GPUDevice, T, TI>;
|
||||
|
||||
#define DECLARE_GPU_SPEC(T) \
|
||||
@ -154,12 +156,12 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
||||
} // namespace functor
|
||||
|
||||
// Registration of the GPU implementations.
|
||||
#define REGISTER_ONE_HOT_GPU_INDEX(type, index_type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("OneHot") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<index_type>("TI") \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("depth"), \
|
||||
#define REGISTER_ONE_HOT_GPU_INDEX(type, index_type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("OneHot") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<index_type>("TI") \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("depth"), \
|
||||
OneHotOp<GPUDevice, type, index_type>);
|
||||
|
||||
#define REGISTER_ONE_HOT_GPU(type) \
|
||||
|
@ -111,6 +111,17 @@ class MaxPoolingOp : public OpKernel {
|
||||
0, params.forward_output_shape(), &output));
|
||||
|
||||
if (params.depth_window > 1) {
|
||||
// Validate spec against the current implementation. A
|
||||
// relaxation of these requirements would be ideal.
|
||||
OP_REQUIRES(context, params.depth % params.depth_window == 0,
|
||||
errors::Unimplemented(
|
||||
"Depthwise max pooling requires "
|
||||
"the depth window to evenly divide the input depth."));
|
||||
OP_REQUIRES(
|
||||
context, params.depth_window == params.depth_stride,
|
||||
errors::Unimplemented("Depthwise max pooling requires "
|
||||
"the depth window to equal the depth stride."));
|
||||
|
||||
DepthwiseMaxPool(context, output, tensor_in, params);
|
||||
} else {
|
||||
SpatialMaxPool(context, output, tensor_in, params, padding_);
|
||||
|
@ -118,16 +118,23 @@ class LinSpaceOp : public OpKernel {
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_CPU_KERNEL(T) \
|
||||
#define REGISTER_KERNEL(DEV, T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("LinSpace") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.Device(DEV) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("start") \
|
||||
.HostMemory("stop") \
|
||||
.HostMemory("num") \
|
||||
.HostMemory("output"), \
|
||||
LinSpaceOp<T>);
|
||||
#define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, T)
|
||||
TF_CALL_float(REGISTER_CPU_KERNEL);
|
||||
TF_CALL_double(REGISTER_CPU_KERNEL);
|
||||
|
||||
// NOTE(touts): We register the op on GPU but it still runs on CPU
|
||||
// because its inputs and outputs are tagged as HostMemory.
|
||||
#define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, T)
|
||||
TF_CALL_float(REGISTER_GPU_KERNEL);
|
||||
TF_CALL_double(REGISTER_GPU_KERNEL);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -1,31 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/monitoring/counter.h"
|
||||
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace monitoring {
|
||||
|
||||
void CounterCell::IncrementBy(const int64 step) {
|
||||
DCHECK_LE(0, step) << "Must not decrement cumulative metrics.";
|
||||
value_ += step;
|
||||
}
|
||||
|
||||
int64 CounterCell::value() const { return value_; }
|
||||
|
||||
} // namespace monitoring
|
||||
} // namespace tensorflow
|
@ -20,6 +20,8 @@ limitations under the License.
|
||||
#include <atomic>
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/core/lib/monitoring/metric_def.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
@ -71,9 +73,12 @@ class CounterCell {
|
||||
template <int NumLabels>
|
||||
class Counter {
|
||||
public:
|
||||
Counter() {}
|
||||
~Counter() {}
|
||||
|
||||
explicit Counter(
|
||||
const MetricDef<MetricKind::CUMULATIVE, int64, NumLabels>& metric_def)
|
||||
: metric_def_(metric_def) {}
|
||||
|
||||
// Retrieves the cell for the specified labels, creating it on demand if
|
||||
// not already present.
|
||||
template <typename... Labels>
|
||||
@ -82,6 +87,10 @@ class Counter {
|
||||
private:
|
||||
mutable mutex mu_;
|
||||
|
||||
// The metric definition. This will be used to identify the metric when we
|
||||
// register it for exporting.
|
||||
const MetricDef<MetricKind::CUMULATIVE, int64, NumLabels> metric_def_;
|
||||
|
||||
using LabelArray = std::array<string, NumLabels>;
|
||||
std::map<LabelArray, CounterCell> cells_ GUARDED_BY(mu_);
|
||||
|
||||
@ -92,6 +101,13 @@ class Counter {
|
||||
// Implementation details follow. API readers may skip.
|
||||
////
|
||||
|
||||
inline void CounterCell::IncrementBy(const int64 step) {
|
||||
DCHECK_LE(0, step) << "Must not decrement cumulative metrics.";
|
||||
value_ += step;
|
||||
}
|
||||
|
||||
inline int64 CounterCell::value() const { return value_; }
|
||||
|
||||
template <int NumLabels>
|
||||
template <typename... Labels>
|
||||
CounterCell* Counter<NumLabels>::GetCell(const Labels&... labels)
|
||||
|
@ -19,26 +19,28 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
namespace monitoring {
|
||||
namespace {
|
||||
|
||||
class LabeledCounterTest : public ::testing::Test {
|
||||
protected:
|
||||
LabeledCounterTest() {}
|
||||
|
||||
Counter<1> counter_;
|
||||
Counter<1> counter_with_labels_{{"/tensorflow/test/counter_with_labels_",
|
||||
"Counter with one label.", "One label"}};
|
||||
};
|
||||
|
||||
TEST_F(LabeledCounterTest, InitializedWithZero) {
|
||||
EXPECT_EQ(0, counter_.GetCell("Empty")->value());
|
||||
EXPECT_EQ(0, counter_with_labels_.GetCell("Empty")->value());
|
||||
}
|
||||
|
||||
TEST_F(LabeledCounterTest, GetCell) {
|
||||
auto* cell = counter_.GetCell("GetCellOp");
|
||||
auto* cell = counter_with_labels_.GetCell("GetCellOp");
|
||||
EXPECT_EQ(0, cell->value());
|
||||
|
||||
cell->IncrementBy(42);
|
||||
EXPECT_EQ(42, cell->value());
|
||||
|
||||
auto* same_cell = counter_.GetCell("GetCellOp");
|
||||
auto* same_cell = counter_with_labels_.GetCell("GetCellOp");
|
||||
EXPECT_EQ(42, same_cell->value());
|
||||
|
||||
same_cell->IncrementBy(58);
|
||||
@ -49,29 +51,31 @@ TEST_F(LabeledCounterTest, GetCell) {
|
||||
using LabeledCounterDeathTest = LabeledCounterTest;
|
||||
|
||||
TEST_F(LabeledCounterDeathTest, DiesOnDecrement) {
|
||||
EXPECT_DEBUG_DEATH({ counter_.GetCell("DyingOp")->IncrementBy(-1); },
|
||||
"decrement");
|
||||
EXPECT_DEBUG_DEATH(
|
||||
{ counter_with_labels_.GetCell("DyingOp")->IncrementBy(-1); },
|
||||
"decrement");
|
||||
}
|
||||
|
||||
class UnlabeledCounterTest : public ::testing::Test {
|
||||
protected:
|
||||
UnlabeledCounterTest() {}
|
||||
|
||||
Counter<0> counter_;
|
||||
Counter<0> counter_without_labels_{
|
||||
{"/tensorflow/test/counter0", "Counter without any labels."}};
|
||||
};
|
||||
|
||||
TEST_F(UnlabeledCounterTest, InitializedWithZero) {
|
||||
EXPECT_EQ(0, counter_.GetCell()->value());
|
||||
EXPECT_EQ(0, counter_without_labels_.GetCell()->value());
|
||||
}
|
||||
|
||||
TEST_F(UnlabeledCounterTest, GetCell) {
|
||||
auto* cell = counter_.GetCell();
|
||||
auto* cell = counter_without_labels_.GetCell();
|
||||
EXPECT_EQ(0, cell->value());
|
||||
|
||||
cell->IncrementBy(42);
|
||||
EXPECT_EQ(42, cell->value());
|
||||
|
||||
auto* same_cell = counter_.GetCell();
|
||||
auto* same_cell = counter_without_labels_.GetCell();
|
||||
EXPECT_EQ(42, same_cell->value());
|
||||
|
||||
same_cell->IncrementBy(58);
|
||||
@ -82,8 +86,10 @@ TEST_F(UnlabeledCounterTest, GetCell) {
|
||||
using UnlabeledCounterDeathTest = UnlabeledCounterTest;
|
||||
|
||||
TEST_F(UnlabeledCounterDeathTest, DiesOnDecrement) {
|
||||
EXPECT_DEBUG_DEATH({ counter_.GetCell()->IncrementBy(-1); }, "decrement");
|
||||
EXPECT_DEBUG_DEATH({ counter_without_labels_.GetCell()->IncrementBy(-1); },
|
||||
"decrement");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace monitoring
|
||||
} // namespace tensorflow
|
||||
|
128
tensorflow/core/lib/monitoring/metric_def.h
Normal file
128
tensorflow/core/lib/monitoring/metric_def.h
Normal file
@ -0,0 +1,128 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace monitoring {
|
||||
|
||||
// Everything in the internal namespace is implementation details. Do not depend
|
||||
// on this.
|
||||
namespace internal {
|
||||
|
||||
// Ensures that the string is a compile-time string literal.
|
||||
class StringLiteral {
|
||||
public:
|
||||
// We allow implicit conversions here on purpose.
|
||||
template <int N>
|
||||
StringLiteral(const char (&data)[N]) : literal_(data, N) {}
|
||||
|
||||
// This ctor will be called for non-literals, causing compile-time failure.
|
||||
template <typename NotStringLiteral>
|
||||
StringLiteral(const NotStringLiteral& not_string_literal) = delete;
|
||||
|
||||
// Implicit conversion to StringPiece.
|
||||
operator StringPiece() const { return literal_; }
|
||||
|
||||
private:
|
||||
const StringPiece literal_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
// The different metric kinds available.
|
||||
//
|
||||
// Gauge indicates that the metric's values are instantaneous measurements of a
|
||||
// (typically) continuously varying quantity. Examples: a process's current heap
|
||||
// size, a queue's current length.
|
||||
//
|
||||
// Cumulative indicates that the metric's values represent non-negative changes
|
||||
// over specified time periods. Example: the number of rpc calls to a service.
|
||||
enum MetricKind { GAUGE, CUMULATIVE };
|
||||
|
||||
// Abstract base class for a metric definition.
|
||||
//
|
||||
// Unlike MetricDef, this class is non-templatized and allows storing and
|
||||
// accessing metric definitions without the full type information.
|
||||
//
|
||||
// Everything except the value type of a metric is stored here. Please read
|
||||
// MetricDef class comments for more details.
|
||||
class AbstractMetricDef {
|
||||
public:
|
||||
MetricKind kind() const { return kind_; }
|
||||
|
||||
StringPiece name() const { return name_; }
|
||||
|
||||
StringPiece description() const { return description_; }
|
||||
|
||||
const std::vector<StringPiece> label_descriptions() const {
|
||||
return label_descriptions_;
|
||||
}
|
||||
|
||||
private:
|
||||
template <MetricKind kind, typename Value, int NumLabels>
|
||||
friend class MetricDef;
|
||||
|
||||
AbstractMetricDef(
|
||||
const MetricKind kind, const internal::StringLiteral name,
|
||||
const internal::StringLiteral description,
|
||||
const std::vector<internal::StringLiteral>& label_descriptions)
|
||||
: kind_(kind),
|
||||
name_(name),
|
||||
description_(description),
|
||||
label_descriptions_(
|
||||
{label_descriptions.begin(), label_descriptions.end()}) {}
|
||||
|
||||
const MetricKind kind_;
|
||||
const StringPiece name_;
|
||||
const StringPiece description_;
|
||||
const std::vector<StringPiece> label_descriptions_;
|
||||
};
|
||||
|
||||
// Metric definition.
|
||||
//
|
||||
// A metric is defined by its kind, value-type, name, description and the
|
||||
// description of its labels.
|
||||
//
|
||||
// NOTE: We allow only string literals for the name, description and label
|
||||
// descriptions because these should be fixed at compile-time and shouldn't be
|
||||
// dynamic.
|
||||
template <MetricKind metric_kind, typename Value, int NumLabels>
|
||||
class MetricDef : public AbstractMetricDef {
|
||||
public:
|
||||
using value_type = Value;
|
||||
|
||||
template <typename... LabelDesc>
|
||||
MetricDef(const internal::StringLiteral name,
|
||||
const internal::StringLiteral description,
|
||||
const LabelDesc&... label_descriptions)
|
||||
: AbstractMetricDef(metric_kind, name, description,
|
||||
{label_descriptions...}) {
|
||||
static_assert(sizeof...(LabelDesc) == NumLabels,
|
||||
"Mismatch between Counter<NumLabels> and number of label "
|
||||
"descriptions.");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace monitoring
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_LIB_MONITORING_METRIC_DEF_H_
|
@ -344,6 +344,34 @@ REGISTER_OP("Split")
|
||||
.Output("output: num_split * T")
|
||||
.Attr("num_split: int >= 1")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
const Dimension* split_dimension;
|
||||
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(0, &split_dimension));
|
||||
int num_split = c->num_outputs();
|
||||
const Shape* input = c->input(1);
|
||||
const Shape* out;
|
||||
if (!c->ValueKnown(split_dimension)) {
|
||||
if (c->RankKnown(input)) {
|
||||
std::vector<const Dimension*> dims;
|
||||
dims.resize(c->Rank(input));
|
||||
for (int i = 0; i < dims.size(); ++i) dims[i] = c->UnknownDim();
|
||||
out = c->MakeShape(dims);
|
||||
} else {
|
||||
out = c->UnknownShape();
|
||||
}
|
||||
} else {
|
||||
int64 split_dim = c->Value(split_dimension);
|
||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, split_dim + 1, &input));
|
||||
const Dimension* split_dim_size;
|
||||
TF_RETURN_WITH_CONTEXT_IF_ERROR(
|
||||
c->Divide(c->Dim(input, split_dim), num_split, &split_dim_size),
|
||||
"Number of ways to split should evenly divide the split dimension");
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ReplaceDim(input, split_dim, split_dim_size, &out));
|
||||
}
|
||||
for (int i = 0; i < num_split; ++i) c->set_output(i, out);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Splits a tensor into `num_split` tensors along one dimension.
|
||||
|
||||
@ -1443,7 +1471,7 @@ Status ShapeShapeFn(InferenceContext* c) {
|
||||
for (int i = 0; i < c->num_inputs(); ++i) {
|
||||
const Dimension* dim;
|
||||
if (c->RankKnown(c->input(i))) {
|
||||
dim = c->MakeDim(c->Rank(c->input(0)));
|
||||
dim = c->MakeDim(c->Rank(c->input(i)));
|
||||
} else {
|
||||
dim = c->UnknownDim();
|
||||
}
|
||||
@ -1479,6 +1507,7 @@ REGISTER_OP("ShapeN")
|
||||
.Output("output: N * int32")
|
||||
.Attr("N: int")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn(ShapeShapeFn)
|
||||
.Doc(R"doc(
|
||||
Returns shape of tensors.
|
||||
|
||||
@ -1493,6 +1522,42 @@ REGISTER_OP("ReverseSequence")
|
||||
.Attr("seq_dim: int")
|
||||
.Attr("batch_dim: int = 0")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
const Shape* input = c->input(0);
|
||||
const Shape* seq_lens_shape;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seq_lens_shape));
|
||||
|
||||
int64 seq_dim;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("seq_dim", &seq_dim));
|
||||
int64 batch_dim;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("batch_dim", &batch_dim));
|
||||
|
||||
if (!c->RankKnown(input)) {
|
||||
return shape_inference::UnknownShape(c);
|
||||
}
|
||||
|
||||
// Validate batch_dim and seq_dim against input.
|
||||
const int32 input_rank = c->Rank(input);
|
||||
if (batch_dim >= input_rank) {
|
||||
return errors::InvalidArgument("batch_dim must be < input rank: ",
|
||||
batch_dim, " vs. ", input_rank);
|
||||
}
|
||||
if (seq_dim >= input_rank) {
|
||||
return errors::InvalidArgument("seq_dim must be < input rank: ",
|
||||
seq_dim, " vs. ", input_rank);
|
||||
}
|
||||
|
||||
const Dimension* batch_dim_dim = c->Dim(input, batch_dim);
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Merge(batch_dim_dim, c->Dim(seq_lens_shape, 0), &batch_dim_dim));
|
||||
|
||||
// Replace batch_dim of input with batch_size
|
||||
const Shape* output_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ReplaceDim(input, batch_dim, batch_dim_dim, &output_shape));
|
||||
c->set_output(0, output_shape);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Reverses variable length slices.
|
||||
|
||||
@ -1564,6 +1629,7 @@ REGISTER_OP("Rank")
|
||||
.Input("input: T")
|
||||
.Output("output: int32")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Doc(R"doc(
|
||||
Returns the rank of a tensor.
|
||||
|
||||
@ -1587,6 +1653,7 @@ REGISTER_OP("Size")
|
||||
.Input("input: T")
|
||||
.Output("output: int32")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn(shape_inference::ScalarShape)
|
||||
.Doc(R"doc(
|
||||
Returns the size of a tensor.
|
||||
|
||||
@ -1669,7 +1736,7 @@ begin_mask: a bitmask where a bit i being 1 means to ignore the begin
|
||||
begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or
|
||||
`[-1, n-1]` if `stride[i] < 0`
|
||||
end_mask: analogous to `begin_mask`
|
||||
ellipsis_mask: a bitmask where bit `i` being 1 means the `i`th
|
||||
ellipsis_mask: a bitmask where bit `i` being 1 means the `i`th
|
||||
position is actually an ellipsis. One bit at most can be 1.
|
||||
new_axis_mask: a bitmask where bit `i` being 1 means the `i`th
|
||||
position creates a dimension in the tensor of length 1. Thus
|
||||
@ -1678,7 +1745,7 @@ new_axis_mask: a bitmask where bit `i` being 1 means the `i`th
|
||||
shrink_axis_mask: a bitmask where bit `i` implies that the `i`th
|
||||
position should shrink the dimensionality. begin and end
|
||||
must imply a slice of size 1 in the dimension. For example in
|
||||
python one might do `foo[:,3,:]` which would result in
|
||||
python one might do `foo[:,3,:]` which would result in
|
||||
`shrink_axis_mask` being 2.
|
||||
)doc");
|
||||
|
||||
@ -1705,7 +1772,7 @@ as `shape`). The gradient will be zero in any element that the slice
|
||||
does not select.
|
||||
|
||||
Arguments are the same as `StridedSlice` with the exception that
|
||||
`dy` is the input gradient to be propagated and `shape` is the
|
||||
`dy` is the input gradient to be propagated and `shape` is the
|
||||
shape of `StridedSlice`'s `input`.
|
||||
)doc");
|
||||
|
||||
@ -1744,7 +1811,14 @@ each repeated tile of `input` into `output`.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("Where").Input("input: bool").Output("index: int64").Doc(R"doc(
|
||||
REGISTER_OP("Where")
|
||||
.Input("input: bool")
|
||||
.Output("index: int64")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0))));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Returns locations of true values in a boolean tensor.
|
||||
|
||||
This operation returns the coordinates of true elements in `input`. The
|
||||
|
@ -237,6 +237,20 @@ TEST(ArrayOpsTest, Shape_ShapeFn) {
|
||||
INFER_OK(op, "[?,2,3,4,5]", "[5]");
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, ShapeN_ShapeFn) {
|
||||
ShapeInferenceTestOp op("ShapeN");
|
||||
int n = 3;
|
||||
std::vector<NodeDefBuilder::NodeOut> src_list;
|
||||
for (int i = 0; i < n; ++i) src_list.emplace_back("a", 0, DT_FLOAT);
|
||||
TF_CHECK_OK(NodeDefBuilder("test", "ShapeN")
|
||||
.Input(src_list)
|
||||
.Attr("N", n)
|
||||
.Finalize(&op.node_def));
|
||||
INFER_OK(op, "?;?;?", "[?];[?];[?]");
|
||||
INFER_OK(op, "[?];[?];[?]", "[1];[1];[1]");
|
||||
INFER_OK(op, "[?,2,3,4,5];?;[1,?,3]", "[5];[?];[3]");
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, Unique_ShapeFn) {
|
||||
ShapeInferenceTestOp op("Unique");
|
||||
INFER_OK(op, "?", "[?];in0");
|
||||
@ -696,4 +710,61 @@ TEST(ArrayOpsTest, Squeeze_ShapeFn) {
|
||||
INFER_ERROR("not in [-3,3)", op, "[1,2,3]");
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, ReverseSequence_ShapeFn) {
|
||||
ShapeInferenceTestOp op("ReverseSequence");
|
||||
auto rebuild_node_def = [&op](const int32 seq_dim, const int32 batch_dim) {
|
||||
TF_CHECK_OK(NodeDefBuilder("test", "ReverseSequence")
|
||||
.Input("input", 0, DT_FLOAT)
|
||||
.Input("seq_lengths", 1, DT_INT64)
|
||||
.Attr("seq_dim", seq_dim)
|
||||
.Attr("batch_dim", batch_dim)
|
||||
.Finalize(&op.node_def));
|
||||
};
|
||||
|
||||
rebuild_node_def(1, 2);
|
||||
// No valid shape provided, so output is unknown.
|
||||
INFER_OK(op, "?;[10]", "?");
|
||||
|
||||
// Bad rank for seq_lengths
|
||||
INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[10,10]");
|
||||
|
||||
// Validate seq_dim and batch_dim
|
||||
rebuild_node_def(1, 4);
|
||||
INFER_ERROR("batch_dim must be < input rank", op, "[1,2,3];[3]");
|
||||
rebuild_node_def(4, 1);
|
||||
INFER_ERROR("seq_dim must be < input rank", op, "[1,2,3];[3]");
|
||||
|
||||
rebuild_node_def(1, 2);
|
||||
INFER_OK(op, "[1,2,3];[3]", "[d0_0,d0_1,d0_2]");
|
||||
// Resolves uncertainty on batch dimension by merging.
|
||||
INFER_OK(op, "[1,2,?];[3]", "[d0_0,d0_1,d1_0]");
|
||||
INFER_OK(op, "[1,2,3];[?]", "[d0_0,d0_1,d0_2]");
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, Split_ShapeFn) {
|
||||
ShapeInferenceTestOp op("Split");
|
||||
op.input_tensors.resize(2);
|
||||
|
||||
// No value for split_dim and no input.
|
||||
TF_CHECK_OK(NodeDefBuilder("test", "Split")
|
||||
.Input("split_dim", 0, DT_INT32)
|
||||
.Input("value", 1, DT_FLOAT)
|
||||
.Attr("num_split", 2)
|
||||
.Finalize(&op.node_def));
|
||||
INFER_OK(op, "?;?", "?;?");
|
||||
// If the rank is known, we know the rank of each output.
|
||||
INFER_OK(op, "?;[?,?]", "[?,?];[?,?]");
|
||||
|
||||
// split_dim is known.
|
||||
Tensor split_dim = test::AsTensor<int32>({1, 2});
|
||||
op.input_tensors[0] = &split_dim;
|
||||
INFER_ERROR("Input must be scalar but has rank 1", op, "[?];[?,?]");
|
||||
split_dim = test::AsScalar<int32>(1);
|
||||
INFER_OK(op, "?;?", "?;?");
|
||||
INFER_OK(op, "?;[?,?]", "[d1_0,?];[d1_0,?]");
|
||||
INFER_OK(op, "?;[1,4]", "[d1_0,2];[d1_0,2]");
|
||||
INFER_OK(op, "?;[1,?]", "[d1_0,?];[d1_0,?]");
|
||||
INFER_ERROR("Dimension size must be divisible by 2 but is 5", op, "?;[1,5]");
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -858,6 +858,7 @@ REGISTER_OP("MaxPool")
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Input("input: T")
|
||||
.Output("output: T")
|
||||
.SetShapeFn(shape_inference::MaxPoolShape)
|
||||
.Doc(R"doc(
|
||||
Performs max pooling on the input.
|
||||
|
||||
@ -914,6 +915,11 @@ REGISTER_OP("MaxPoolWithArgmax")
|
||||
.Output("output: T")
|
||||
.Output("argmax: Targmax")
|
||||
.Attr("T: {float, half} = DT_FLOAT")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
|
||||
c->set_output(1, c->output(0));
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Performs max pooling on the input and outputs both max values and indices.
|
||||
|
||||
|
@ -21,7 +21,7 @@ Then:
|
||||
```python
|
||||
import tensorflow as tf
|
||||
|
||||
with tf.Session("local"):
|
||||
with tf.Session():
|
||||
input1 = tf.constant(1.0, shape=[1, 1], name="input1")
|
||||
input2 = tf.constant(2.0, shape=[1, 1], name="input2")
|
||||
output = tf.matmul(input1, input2)
|
||||
|
@ -373,8 +373,7 @@ the session constructor.
|
||||
|
||||
|
||||
* <b>`target`</b>: (Optional.) The execution engine to connect to.
|
||||
Defaults to using an in-process engine. At present, no value
|
||||
other than the empty string is supported.
|
||||
Defaults to using an in-process engine.
|
||||
* <b>`graph`</b>: (Optional.) The `Graph` to be launched (described above).
|
||||
* <b>`config`</b>: (Optional) `ConfigProto` proto used to configure the session.
|
||||
|
||||
|
@ -53,8 +53,7 @@ the session constructor.
|
||||
|
||||
|
||||
* <b>`target`</b>: (Optional.) The execution engine to connect to.
|
||||
Defaults to using an in-process engine. At present, no value
|
||||
other than the empty string is supported.
|
||||
Defaults to using an in-process engine.
|
||||
* <b>`graph`</b>: (Optional.) The `Graph` to be launched (described above).
|
||||
* <b>`config`</b>: (Optional) `ConfigProto` proto used to configure the session.
|
||||
|
||||
|
@ -147,7 +147,7 @@ graphs and running steps; we also have an experimental API for
|
||||
|
||||
We would like to support more client languages, as determined by community
|
||||
interest. TensorFlow has a
|
||||
[C-based client API](https://www.tensorflow.org/code/tensorflow/core/public/tensor_c_api.h)
|
||||
[C-based client API](https://www.tensorflow.org/code/tensorflow/c/c_api.h)
|
||||
that makes it easy to build a client in many different languages. We invite
|
||||
contributions of new language bindings.
|
||||
|
||||
|
@ -2,6 +2,10 @@
|
||||
|
||||
## Basic Neural Networks
|
||||
|
||||
The first few TensorFlow tutorials guide you through training and testing a
|
||||
simple neural network to classify handwritten digits from the MNIST database of
|
||||
digit images.
|
||||
|
||||
### MNIST For ML Beginners
|
||||
|
||||
If you're new to machine learning, we recommend starting here. You'll learn
|
||||
@ -27,13 +31,6 @@ example.
|
||||
|
||||
[View Tutorial](../tutorials/mnist/tf/index.md)
|
||||
|
||||
### MNIST Data Download
|
||||
|
||||
Details about downloading the MNIST handwritten digits data set. Exciting
|
||||
stuff.
|
||||
|
||||
[View Tutorial](../tutorials/mnist/download/index.md)
|
||||
|
||||
|
||||
## Easy ML with tf.contrib.learn
|
||||
|
||||
|
@ -2,7 +2,6 @@
|
||||
mnist/beginners/index.md
|
||||
mnist/pros/index.md
|
||||
mnist/tf/index.md
|
||||
mnist/download/index.md
|
||||
### Easy ML with tf.contrib.learn
|
||||
tflearn/index.md
|
||||
linear/overview.md
|
||||
|
@ -1,11 +1,11 @@
|
||||
# MNIST For ML Beginners
|
||||
|
||||
*This tutorial is intended for readers who are new to both machine learning and
|
||||
TensorFlow. If you already
|
||||
know what MNIST is, and what softmax (multinomial logistic) regression is,
|
||||
you might prefer this [faster paced tutorial](../pros/index.md).
|
||||
Be sure to [install TensorFlow](../../../get_started/os_setup.md) before
|
||||
starting either tutorial.*
|
||||
TensorFlow. If you already know what MNIST is, and what softmax (multinomial
|
||||
logistic) regression is, you might prefer this
|
||||
[faster paced tutorial](../pros/index.md). Be sure to
|
||||
[install TensorFlow](../../../get_started/os_setup.md) before starting either
|
||||
tutorial.*
|
||||
|
||||
When one learns how to program, there's a tradition that the first thing you do
|
||||
is print "Hello World." Just like programming has Hello World, machine learning
|
||||
@ -33,21 +33,45 @@ important to understand the ideas behind it: both how TensorFlow works and the
|
||||
core machine learning concepts. Because of this, we are going to very carefully
|
||||
work through the code.
|
||||
|
||||
## About this tutorial
|
||||
|
||||
This tutorial is an explanation, line by line, of what is happening in the
|
||||
[mnist_softmax.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_softmax.py) code.
|
||||
|
||||
You can use this tutorial in a few different ways, including:
|
||||
|
||||
- Copy and paste each code snippet, line by line, into a Python environment as
|
||||
you read through the explanations of each line.
|
||||
|
||||
- Run the entire `mnist_softmax.py` Python file either before or after reading
|
||||
through the explanations, and use this tutorial to understand the lines of
|
||||
code that aren't clear to you.
|
||||
|
||||
What we will accomplish in this tutorial:
|
||||
|
||||
- Learn about the MNIST data and softmax regressions
|
||||
|
||||
- Create a function that is a model for recognizing digits, based on looking at
|
||||
every pixel in the image
|
||||
|
||||
- Use TensorFlow to train the model to recognize digits by having it "look" at
|
||||
thousands of examples (and run our first TensorFlow session to do so)
|
||||
|
||||
- Check the model's accuracy with our test data
|
||||
|
||||
## The MNIST Data
|
||||
|
||||
The MNIST data is hosted on
|
||||
[Yann LeCun's website](http://yann.lecun.com/exdb/mnist/). For your
|
||||
convenience, we've included some python code to download and install the data
|
||||
automatically. You can either download
|
||||
[the code](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/input_data.py)
|
||||
and import it as below, or simply copy and paste it in.
|
||||
[Yann LeCun's website](http://yann.lecun.com/exdb/mnist/). If you are copying and
|
||||
pasting in the code from this tutorial, start here with these two lines of code
|
||||
which will download and read in the data automatically:
|
||||
|
||||
```python
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
|
||||
```
|
||||
|
||||
The downloaded data is split into three parts, 55,000 data points of training
|
||||
The MNIST data is split into three parts: 55,000 data points of training
|
||||
data (`mnist.train`), 10,000 points of test data (`mnist.test`), and 5,000
|
||||
points of validation data (`mnist.validation`). This split is very important:
|
||||
it's essential in machine learning that we have separate data which we don't
|
||||
@ -55,10 +79,10 @@ learn from so that we can make sure that what we've learned actually
|
||||
generalizes!
|
||||
|
||||
As mentioned earlier, every MNIST data point has two parts: an image of a
|
||||
handwritten digit and a corresponding label. We will call the images "xs" and
|
||||
the labels "ys". Both the training set and test set contain xs and ys, for
|
||||
example the training images are `mnist.train.images` and the train labels are
|
||||
`mnist.train.labels`.
|
||||
handwritten digit and a corresponding label. We'll call the images "x"
|
||||
and the labels "y". Both the training set and test set contain images and their
|
||||
corresponding labels; for example the training images are `mnist.train.images`
|
||||
and the training labels are `mnist.train.labels`.
|
||||
|
||||
Each image is 28 pixels by 28 pixels. We can interpret this as a big array of
|
||||
numbers:
|
||||
@ -77,26 +101,26 @@ From this perspective, the MNIST images are just a bunch of points in a
|
||||
Flattening the data throws away information about the 2D structure of the image.
|
||||
Isn't that bad? Well, the best computer vision methods do exploit this
|
||||
structure, and we will in later tutorials. But the simple method we will be
|
||||
using here, a softmax regression, won't.
|
||||
using here, a softmax regression (defined below), won't.
|
||||
|
||||
The result is that `mnist.train.images` is a tensor (an n-dimensional array)
|
||||
with a shape of `[55000, 784]`. The first dimension indexes the images and the
|
||||
second dimension indexes the pixels in each image. Each entry in the tensor is
|
||||
the pixel intensity between 0 and 1, for a particular pixel in a particular
|
||||
image.
|
||||
with a shape of `[55000, 784]`. The first dimension is an index into the list
|
||||
of images and the second dimension is the index for each pixel in each image.
|
||||
Each entry in the tensor is a pixel intensity between 0 and 1, for a particular
|
||||
pixel in a particular image.
|
||||
|
||||
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../../../images/mnist-train-xs.png">
|
||||
</div>
|
||||
|
||||
The corresponding labels in MNIST are numbers between 0 and 9, describing
|
||||
which digit a given image is of.
|
||||
For the purposes of this tutorial, we're going to want our labels
|
||||
as "one-hot vectors". A one-hot vector is a vector which is 0 in most
|
||||
dimensions, and 1 in a single dimension. In this case, the \\(n\\)th digit will
|
||||
be represented as a vector which is 1 in the \\(n\\)th dimensions. For example,
|
||||
3 would be \\([0,0,0,1,0,0,0,0,0,0]\\).
|
||||
Consequently, `mnist.train.labels` is a
|
||||
Each image in MNIST has a corresponding label, a number between 0 and 9
|
||||
representing the digit drawn in the image.
|
||||
|
||||
For the purposes of this tutorial, we're going to want our labels as "one-hot
|
||||
vectors". A one-hot vector is a vector which is 0 in most dimensions, and 1 in a
|
||||
single dimension. In this case, the \\(n\\)th digit will be represented as a
|
||||
vector which is 1 in the \\(n\\)th dimension. For example, 3 would be
|
||||
\\([0,0,0,1,0,0,0,0,0,0]\\). Consequently, `mnist.train.labels` is a
|
||||
`[55000, 10]` array of floats.
|
||||
|
||||
<div style="width:40%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
@ -107,24 +131,26 @@ We're now ready to actually make our model!
|
||||
|
||||
## Softmax Regressions
|
||||
|
||||
We know that every image in MNIST is a digit, whether it's a zero or a nine. We
|
||||
want to be able to look at an image and give probabilities for it being each
|
||||
We know that every image in MNIST is of a handwritten digit between zero and
|
||||
nine. So there are only ten possible things that a given image can be. We want
|
||||
to be able to look at an image and give the probabilities for it being each
|
||||
digit. For example, our model might look at a picture of a nine and be 80% sure
|
||||
it's a nine, but give a 5% chance to it being an eight (because of the top loop)
|
||||
and a bit of probability to all the others because it isn't sure.
|
||||
and a bit of probability to all the others because it isn't 100% sure.
|
||||
|
||||
This is a classic case where a softmax regression is a natural, simple model.
|
||||
If you want to assign probabilities to an object being one of several different
|
||||
things, softmax is the thing to do. Even later on, when we train more
|
||||
sophisticated models, the final step will be a layer of softmax.
|
||||
things, softmax is the thing to do, because softmax gives us a list of values
|
||||
between 0 and 1 that add up to 1. Even later on, when we train more sophisticated
|
||||
models, the final step will be a layer of softmax.
|
||||
|
||||
A softmax regression has two steps: first we add up the evidence of our input
|
||||
being in certain classes, and then we convert that evidence into probabilities.
|
||||
|
||||
To tally up the evidence that a given image is in a particular class, we do a
|
||||
weighted sum of the pixel intensities. The weight is negative if that pixel
|
||||
having a high intensity is evidence against the image being in that class,
|
||||
and positive if it is evidence in favor.
|
||||
having a high intensity is evidence against the image being in that class, and
|
||||
positive if it is evidence in favor.
|
||||
|
||||
The following diagram shows the weights one model learned for each of these
|
||||
classes. Red represents negative weights, while blue represents positive
|
||||
@ -160,18 +186,16 @@ If you expand that equation out, you get:
|
||||
|
||||
$$\text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$
|
||||
|
||||
But it's often more helpful to think of softmax the first way:
|
||||
exponentiating its inputs and then normalizing them.
|
||||
The exponentiation means that one more unit of evidence increases the weight
|
||||
given to any hypothesis multiplicatively.
|
||||
And conversely, having one less unit of evidence means that a
|
||||
hypothesis gets a fraction of its earlier weight. No hypothesis ever has zero
|
||||
or negative weight. Softmax then normalizes these weights, so that they add up
|
||||
to one, forming a valid probability distribution. (To get more intuition about
|
||||
the softmax function, check out the
|
||||
[section](http://neuralnetworksanddeeplearning.com/chap3.html#softmax)
|
||||
on it in Michael Nielsen's book, complete with an interactive visualization.)
|
||||
|
||||
But it's often more helpful to think of softmax the first way: exponentiating
|
||||
its inputs and then normalizing them. The exponentiation means that one more
|
||||
unit of evidence increases the weight given to any hypothesis multiplicatively.
|
||||
And conversely, having one less unit of evidence means that a hypothesis gets a
|
||||
fraction of its earlier weight. No hypothesis ever has zero or negative
|
||||
weight. Softmax then normalizes these weights, so that they add up to one,
|
||||
forming a valid probability distribution. (To get more intuition about the
|
||||
softmax function, check out the
|
||||
[section](http://neuralnetworksanddeeplearning.com/chap3.html#softmax) on it in
|
||||
Michael Nielsen's book, complete with an interactive visualization.)
|
||||
|
||||
You can picture our softmax regression as looking something like the following,
|
||||
although with a lot more \\(x\\)s. For each output, we compute a weighted sum of
|
||||
@ -199,26 +223,26 @@ More compactly, we can just write:
|
||||
|
||||
$$y = \text{softmax}(Wx + b)$$
|
||||
|
||||
Now let's turn that into something that TensorFlow can use.
|
||||
|
||||
## Implementing the Regression
|
||||
|
||||
|
||||
To do efficient numerical computing in Python, we typically use libraries like
|
||||
NumPy that do expensive operations such as matrix multiplication outside Python,
|
||||
using highly efficient code implemented in another language.
|
||||
Unfortunately, there can still be a lot of overhead from switching back to
|
||||
Python every operation. This overhead is especially bad if you want to run
|
||||
computations on GPUs or in a distributed manner, where there can be a high cost
|
||||
to transferring data.
|
||||
[NumPy](http://www.numpy.org/) that do expensive operations such as matrix
|
||||
multiplication outside Python, using highly efficient code implemented in
|
||||
another language. Unfortunately, there can still be a lot of overhead from
|
||||
switching back to Python after every operation. This overhead is especially bad if you
|
||||
want to run computations on GPUs or in a distributed manner, where there can be
|
||||
a high cost to transferring data.
|
||||
|
||||
TensorFlow also does its heavy lifting outside python,
|
||||
but it takes things a step further to avoid this overhead.
|
||||
Instead of running a single expensive operation independently
|
||||
from Python, TensorFlow lets us describe a graph of interacting operations that
|
||||
run entirely outside Python. (Approaches like this can be seen in a few
|
||||
machine learning libraries.)
|
||||
TensorFlow also does its heavy lifting outside Python, but it takes things a
|
||||
step further to avoid this overhead. Instead of running a single expensive
|
||||
operation independently from Python, TensorFlow lets us describe a graph of
|
||||
interacting operations that run entirely outside Python. (Approaches like this
|
||||
can be seen in a few machine learning libraries.)
|
||||
|
||||
To use TensorFlow, we need to import it.
|
||||
To use TensorFlow, first we need to import it.
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
@ -239,11 +263,10 @@ this as a 2-D tensor of floating-point numbers, with a shape `[None, 784]`.
|
||||
|
||||
We also need the weights and biases for our model. We could imagine treating
|
||||
these like additional inputs, but TensorFlow has an even better way to handle
|
||||
it: `Variable`.
|
||||
A `Variable` is a modifiable tensor that lives in TensorFlow's graph of
|
||||
interacting
|
||||
operations. It can be used and even modified by the computation. For machine
|
||||
learning applications, one generally has the model parameters be `Variable`s.
|
||||
it: `Variable`. A `Variable` is a modifiable tensor that lives in TensorFlow's
|
||||
graph of interacting operations. It can be used and even modified by the
|
||||
computation. For machine learning applications, one generally has the model
|
||||
parameters be `Variable`s.
|
||||
|
||||
```python
|
||||
W = tf.Variable(tf.zeros([784, 10]))
|
||||
@ -260,17 +283,16 @@ Notice that `W` has a shape of [784, 10] because we want to multiply the
|
||||
evidence for the different classes. `b` has a shape of [10] so we can add it
|
||||
to the output.
|
||||
|
||||
We can now implement our model. It only takes one line!
|
||||
We can now implement our model. It only takes one line to define it!
|
||||
|
||||
```python
|
||||
y = tf.nn.softmax(tf.matmul(x, W) + b)
|
||||
```
|
||||
|
||||
First, we multiply `x` by `W` with the expression `tf.matmul(x, W)`. This is
|
||||
flipped from when we multiplied them in our equation, where we had \\(Wx\\), as a
|
||||
small trick
|
||||
to deal with `x` being a 2D tensor with multiple inputs. We then add `b`, and
|
||||
finally apply `tf.nn.softmax`.
|
||||
flipped from when we multiplied them in our equation, where we had \\(Wx\\), as
|
||||
a small trick to deal with `x` being a 2D tensor with multiple inputs. We then
|
||||
add `b`, and finally apply `tf.nn.softmax`.
|
||||
|
||||
That's it. It only took us one line to define our model, after a couple short
|
||||
lines of setup. That isn't because TensorFlow is designed to make a softmax
|
||||
@ -282,79 +304,80 @@ your computer's CPU, GPUs, and even phones!
|
||||
|
||||
## Training
|
||||
|
||||
In order to train our model, we need to define what it means for the model to
|
||||
be good. Well, actually, in machine learning we typically define what it means
|
||||
for a model to be bad, called the cost or loss, and then try to minimize how bad
|
||||
it is. But the two are equivalent.
|
||||
In order to train our model, we need to define what it means for the model to be
|
||||
good. Well, actually, in machine learning we typically define what it means for
|
||||
a model to be bad. We call this the cost, or the loss, and it represents how far
|
||||
off our model is from our desired outcome. We try to minimize that error, and
|
||||
the smaller the error margin, the better our model is.
|
||||
|
||||
One very common, very nice cost function is "cross-entropy." Surprisingly,
|
||||
cross-entropy arises from thinking about information compressing codes in
|
||||
information theory but it winds up being an important idea in lots of areas,
|
||||
from gambling to machine learning. It's defined:
|
||||
One very common, very nice function to determine the loss of a model is called
|
||||
"cross-entropy." Cross-entropy arises from thinking about information
|
||||
compressing codes in information theory but it winds up being an important idea
|
||||
in lots of areas, from gambling to machine learning. It's defined as:
|
||||
|
||||
$$H_{y'}(y) = -\sum_i y'_i \log(y_i)$$
|
||||
|
||||
Where \\(y\\) is our predicted probability distribution, and \\(y'\\) is the true
|
||||
distribution (the one-hot vector we'll input). In some rough sense, the
|
||||
distribution (the one-hot vector with the digit labels). In some rough sense, the
|
||||
cross-entropy is measuring how inefficient our predictions are for describing
|
||||
the truth. Going into more detail about cross-entropy is beyond the scope of
|
||||
this tutorial, but it's well worth
|
||||
[understanding](http://colah.github.io/posts/2015-09-Visual-Information/).
|
||||
|
||||
To implement cross-entropy we need to first add a new placeholder to input
|
||||
the correct answers:
|
||||
To implement cross-entropy we need to first add a new placeholder to input the
|
||||
correct answers:
|
||||
|
||||
```python
|
||||
y_ = tf.placeholder(tf.float32, [None, 10])
|
||||
```
|
||||
|
||||
Then we can implement the cross-entropy, \\(-\sum y'\log(y)\\):
|
||||
Then we can implement the cross-entropy function, \\(-\sum y'\log(y)\\):
|
||||
|
||||
```python
|
||||
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
|
||||
```
|
||||
|
||||
First, `tf.log` computes the logarithm of each element of `y`. Next, we multiply
|
||||
each element of `y_` with the corresponding element of `tf.log(y)`. Then
|
||||
`tf.reduce_sum` adds the elements in the second dimension of y, due to the
|
||||
`reduction_indices=[1]` parameter. Finally, `tf.reduce_mean` computes the mean
|
||||
each element of `y_` with the corresponding element of `tf.log(y)`. Then
|
||||
`tf.reduce_sum` adds the elements in the second dimension of y, due to the
|
||||
`reduction_indices=[1]` parameter. Finally, `tf.reduce_mean` computes the mean
|
||||
over all the examples in the batch.
|
||||
|
||||
Now that we know what we want our model to do, it's very easy to have TensorFlow
|
||||
train it to do so.
|
||||
Because TensorFlow knows the entire graph of your computations, it
|
||||
can automatically use the [backpropagation
|
||||
algorithm](http://colah.github.io/posts/2015-08-Backprop/)
|
||||
to efficiently determine how your variables affect the cost you ask it to
|
||||
train it to do so. Because TensorFlow knows the entire graph of your
|
||||
computations, it can automatically use the
|
||||
[backpropagation algorithm](http://colah.github.io/posts/2015-08-Backprop/) to
|
||||
efficiently determine how your variables affect the loss you ask it to
|
||||
minimize. Then it can apply your choice of optimization algorithm to modify the
|
||||
variables and reduce the cost.
|
||||
variables and reduce the loss.
|
||||
|
||||
```python
|
||||
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
|
||||
```
|
||||
|
||||
In this case, we ask TensorFlow to minimize `cross_entropy` using the gradient
|
||||
descent algorithm with a learning rate of 0.5. Gradient descent is a simple
|
||||
procedure, where TensorFlow simply shifts each variable a little bit in the
|
||||
direction that reduces the cost. But TensorFlow also provides
|
||||
In this case, we ask TensorFlow to minimize `cross_entropy` using the
|
||||
[gradient descent algorithm](https://en.wikipedia.org/wiki/Gradient_descent)
|
||||
with a learning rate of 0.5. Gradient descent is a simple procedure, where
|
||||
TensorFlow simply shifts each variable a little bit in the direction that
|
||||
reduces the loss. But TensorFlow also provides
|
||||
[many other optimization algorithms]
|
||||
(../../../api_docs/python/train.md#optimizers): using one is as simple as
|
||||
tweaking one line.
|
||||
|
||||
What TensorFlow actually does here, behind the scenes, is it adds new operations
|
||||
to your graph which
|
||||
implement backpropagation and gradient descent. Then it gives you back a
|
||||
single operation which, when run, will do a step of gradient descent training,
|
||||
slightly tweaking your variables to reduce the cost.
|
||||
What TensorFlow actually does here, behind the scenes, is to add new operations
|
||||
to your graph which implement backpropagation and gradient descent. Then it
|
||||
gives you back a single operation which, when run, does a step of gradient
|
||||
descent training, slightly tweaking your variables to reduce the loss.
|
||||
|
||||
Now we have our model set up to train. One last thing before we launch it,
|
||||
we have to add an operation to initialize the variables we created:
|
||||
Now we have our model set up to train. One last thing before we launch it, we
|
||||
have to create an operation to initialize the variables we created. Note that
|
||||
this defines the operation but does not run it yet:
|
||||
|
||||
```python
|
||||
init = tf.initialize_all_variables()
|
||||
```
|
||||
|
||||
We can now launch the model in a `Session`, and run the operation that
|
||||
We can now launch the model in a `Session`, and now we run the operation that
|
||||
initializes the variables:
|
||||
|
||||
```python
|
||||
@ -374,10 +397,10 @@ Each step of the loop, we get a "batch" of one hundred random data points from
|
||||
our training set. We run `train_step`, feeding in the batch data to replace
|
||||
the `placeholder`s.
|
||||
|
||||
Using small batches of random data is called stochastic training -- in
|
||||
this case, stochastic gradient descent. Ideally, we'd like to use all our data
|
||||
for every step of training because that would give us a better sense of what
|
||||
we should be doing, but that's expensive. So, instead, we use a different subset
|
||||
Using small batches of random data is called stochastic training -- in this
|
||||
case, stochastic gradient descent. Ideally, we'd like to use all our data for
|
||||
every step of training because that would give us a better sense of what we
|
||||
should be doing, but that's expensive. So, instead, we use a different subset
|
||||
every time. Doing this is cheap and has much of the same benefit.
|
||||
|
||||
|
||||
@ -414,12 +437,12 @@ print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}
|
||||
This should be about 92%.
|
||||
|
||||
Is that good? Well, not really. In fact, it's pretty bad. This is because we're
|
||||
using a very simple model. With some small changes, we can get to
|
||||
97%. The best models can get to over 99.7% accuracy! (For more information, have
|
||||
a look at this
|
||||
using a very simple model. With some small changes, we can get to 97%. The best
|
||||
models can get to over 99.7% accuracy! (For more information, have a look at
|
||||
this
|
||||
[list of results](http://rodrigob.github.io/are_we_there_yet/build/classification_datasets_results.html).)
|
||||
|
||||
What matters is that we learned from this model. Still, if you're feeling a bit
|
||||
down about these results, check out [the next tutorial](../../../tutorials/mnist/pros/index.md) where we
|
||||
do a lot better, and learn how to build more sophisticated models using
|
||||
TensorFlow!
|
||||
down about these results, check out
|
||||
[the next tutorial](../../../tutorials/mnist/pros/index.md) where we do a lot
|
||||
better, and learn how to build more sophisticated models using TensorFlow!
|
||||
|
@ -1,85 +0,0 @@
|
||||
# MNIST Data Download
|
||||
|
||||
Code: [tensorflow/examples/tutorials/mnist/](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/)
|
||||
|
||||
The goal of this tutorial is to show how to download the dataset files required
|
||||
for handwritten digit classification using the (classic) MNIST data set.
|
||||
|
||||
## Tutorial Files
|
||||
|
||||
This tutorial references the following files:
|
||||
|
||||
File | Purpose
|
||||
--- | ---
|
||||
[`input_data.py`](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/input_data.py) | The code to download the MNIST dataset for training and evaluation.
|
||||
|
||||
## Prepare the Data
|
||||
|
||||
MNIST is a classic problem in machine learning. The problem is to look at
|
||||
greyscale 28x28 pixel images of handwritten digits and determine which digit
|
||||
the image represents, for all the digits from zero to nine.
|
||||
|
||||

|
||||
|
||||
For more information, refer to [Yann LeCun's MNIST page](http://yann.lecun.com/exdb/mnist/)
|
||||
or [Chris Olah's visualizations of MNIST](http://colah.github.io/posts/2014-10-Visualizing-MNIST/).
|
||||
|
||||
### Download
|
||||
|
||||
[Yann LeCun's MNIST page](http://yann.lecun.com/exdb/mnist/)
|
||||
also hosts the training and test data for download.
|
||||
|
||||
File | Purpose
|
||||
--- | ---
|
||||
[`train-images-idx3-ubyte.gz`](http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz) | training set images - 55000 training images, 5000 validation images
|
||||
[`train-labels-idx1-ubyte.gz`](http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz) | training set labels matching the images
|
||||
[`t10k-images-idx3-ubyte.gz`](http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz) | test set images - 10000 images
|
||||
[`t10k-labels-idx1-ubyte.gz`](http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz) | test set labels matching the images
|
||||
|
||||
In the `input_data.py` file, the `maybe_download()` function will ensure these
|
||||
files are downloaded into a local data folder for training.
|
||||
|
||||
The folder name is specified in a flag variable at the top of the
|
||||
`fully_connected_feed.py` file and may be changed to fit your needs.
|
||||
|
||||
### Unpack and Reshape
|
||||
|
||||
The files themselves are not in any standard image format and are manually
|
||||
unpacked (following the instructions available at the website) by the
|
||||
`extract_images()` and `extract_labels()` functions in `input_data.py`.
|
||||
|
||||
The image data is extracted into a 2d tensor of: `[image index, pixel index]`
|
||||
where each entry is the intensity value of a specific pixel in a specific
|
||||
image, rescaled from `[0, 255]` to `[0, 1]`. The "image index" corresponds
|
||||
to an image in the dataset, counting up from zero to the size of the dataset.
|
||||
And the "pixel index" corresponds to a specific pixel in that image, ranging
|
||||
from zero to the number of pixels in the image.
|
||||
|
||||
The 60000 examples in the `train-*` files are then split into 55000 examples
|
||||
for training and 5000 examples for validation. For all of the 28x28
|
||||
pixel greyscale images in the datasets the image size is 784 and so the output
|
||||
tensor for the training set images is of shape `[55000, 784]`.
|
||||
|
||||
The label data is extracted into a 1d tensor of: `[image index]`
|
||||
with the class identifier for each example as the value. For the training set
|
||||
labels, this would then be of shape `[55000]`.
|
||||
|
||||
### DataSet Object
|
||||
|
||||
The underlying code will download, unpack, and reshape images and labels for
|
||||
the following datasets:
|
||||
|
||||
Dataset | Purpose
|
||||
--- | ---
|
||||
`data_sets.train` | 55000 images and labels, for primary training.
|
||||
`data_sets.validation` | 5000 images and labels, for iterative validation of training accuracy.
|
||||
`data_sets.test` | 10000 images and labels, for final testing of trained accuracy.
|
||||
|
||||
The `read_data_sets()` function will return a dictionary with a `DataSet`
|
||||
instance for each of these three sets of data. The `DataSet.next_batch()`
|
||||
method can be used to fetch a tuple consisting of `batch_size` lists of images
|
||||
and labels to be fed into the running TensorFlow session.
|
||||
|
||||
```python
|
||||
images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size)
|
||||
```
|
@ -1,10 +1,9 @@
|
||||
# Deep MNIST for Experts
|
||||
# Deep MNIST for Experts
|
||||
|
||||
TensorFlow is a powerful library for doing large-scale numerical computation.
|
||||
One of the tasks at which it excels is implementing and training deep neural
|
||||
networks.
|
||||
In this tutorial we will learn the basic building blocks of a TensorFlow model
|
||||
while constructing a deep convolutional MNIST classifier.
|
||||
networks. In this tutorial we will learn the basic building blocks of a
|
||||
TensorFlow model while constructing a deep convolutional MNIST classifier.
|
||||
|
||||
*This introduction assumes familiarity with neural networks and the MNIST
|
||||
dataset. If you don't have
|
||||
@ -12,6 +11,30 @@ a background with them, check out the
|
||||
[introduction for beginners](../beginners/index.md). Be sure to
|
||||
[install TensorFlow](../../../get_started/os_setup.md) before starting.*
|
||||
|
||||
|
||||
## About this tutorial
|
||||
|
||||
The first part of this tutorial explains what is happening in the
|
||||
[mnist_softmax.py](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_softmax.py)
|
||||
code, which is a basic implementation of a TensorFlow model. The second part
|
||||
shows some ways to improve the accuracy.
|
||||
|
||||
You can copy and paste each code snippet from this tutorial into a Python
|
||||
environment, or you can choose to just read through the code.
|
||||
|
||||
What we will accomplish in this tutorial:
|
||||
|
||||
- Create a softmax regression function that is a model for recognizing MNIST
|
||||
digits, based on looking at every pixel in the image
|
||||
|
||||
- Use TensorFlow to train the model to recognize digits by having it "look" at
|
||||
thousands of examples (and run our first TensorFlow session to do so)
|
||||
|
||||
- Check the model's accuracy with our test data
|
||||
|
||||
- Build, train, and test a multilayer convolutional neural network to improve
|
||||
the results
|
||||
|
||||
## Setup
|
||||
|
||||
Before we create our model, we will first load the MNIST dataset, and start a
|
||||
@ -19,10 +42,8 @@ TensorFlow session.
|
||||
|
||||
### Load MNIST Data
|
||||
|
||||
For your convenience, we've included
|
||||
[a script](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/input_data.py)
|
||||
which will help you download and import the MNIST dataset. Run the following commands to create a
|
||||
directory `'MNIST_data'` in the current folder, the data files will be stored inside that directory.
|
||||
If you are copying and pasting in the code from this tutorial, start here with
|
||||
these two lines of code which will download and read in the data automatically:
|
||||
|
||||
```python
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
@ -30,9 +51,8 @@ mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
|
||||
```
|
||||
|
||||
Here `mnist` is a lightweight class which stores the training, validation, and
|
||||
testing sets as NumPy arrays.
|
||||
It also provides a function for iterating through data minibatches, which we
|
||||
will use below.
|
||||
testing sets as NumPy arrays. It also provides a function for iterating through
|
||||
data minibatches, which we will use below.
|
||||
|
||||
### Start TensorFlow InteractiveSession
|
||||
|
||||
@ -40,17 +60,15 @@ TensorFlow relies on a highly efficient C++ backend to do its computation. The
|
||||
connection to this backend is called a session. The common usage for TensorFlow
|
||||
programs is to first create a graph and then launch it in a session.
|
||||
|
||||
Here we instead use the convenient `InteractiveSession` class, which
|
||||
makes TensorFlow more flexible about how you
|
||||
structure your code.
|
||||
It allows you to interleave operations which build a
|
||||
Here we instead use the convenient `InteractiveSession` class, which makes
|
||||
TensorFlow more flexible about how you structure your code. It allows you to
|
||||
interleave operations which build a
|
||||
[computation graph](../../../get_started/basic_usage.md#the-computation-graph)
|
||||
with ones that run the graph.
|
||||
This is particularly convenient when working in interactive contexts like
|
||||
IPython.
|
||||
If you are not using an `InteractiveSession`, then you should build
|
||||
the entire computation graph before starting a session and [launching the
|
||||
graph](../../../get_started/basic_usage.md#launching-the-graph-in-a-session).
|
||||
with ones that run the graph. This is particularly convenient when working in
|
||||
interactive contexts like IPython. If you are not using an
|
||||
`InteractiveSession`, then you should build the entire computation graph before
|
||||
starting a session and
|
||||
[launching the graph](../../../get_started/basic_usage.md#launching-the-graph-in-a-session).
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
@ -60,19 +78,18 @@ sess = tf.InteractiveSession()
|
||||
#### Computation Graph
|
||||
|
||||
To do efficient numerical computing in Python, we typically use libraries like
|
||||
NumPy that do expensive operations such as matrix multiplication outside Python,
|
||||
using highly efficient code implemented in another language.
|
||||
Unfortunately, there can still be a lot of overhead from switching back to
|
||||
Python every operation. This overhead is especially bad if you want to run
|
||||
computations on GPUs or in a distributed manner, where there can be a high cost
|
||||
to transferring data.
|
||||
[NumPy](http://www.numpy.org/) that do expensive operations such as matrix
|
||||
multiplication outside Python, using highly efficient code implemented in
|
||||
another language. Unfortunately, there can still be a lot of overhead from
|
||||
switching back to Python every operation. This overhead is especially bad if you
|
||||
want to run computations on GPUs or in a distributed manner, where there can be
|
||||
a high cost to transferring data.
|
||||
|
||||
TensorFlow also does its heavy lifting outside Python,
|
||||
but it takes things a step further to avoid this overhead.
|
||||
Instead of running a single expensive operation independently
|
||||
from Python, TensorFlow lets us describe a graph of interacting operations that
|
||||
run entirely outside Python.
|
||||
This approach is similar to that used in Theano or Torch.
|
||||
TensorFlow also does its heavy lifting outside Python, but it takes things a
|
||||
step further to avoid this overhead. Instead of running a single expensive
|
||||
operation independently from Python, TensorFlow lets us describe a graph of
|
||||
interacting operations that run entirely outside Python. This approach is
|
||||
similar to that used in Theano or Torch.
|
||||
|
||||
The role of the Python code is therefore to build this external computation
|
||||
graph, and to dictate which parts of the computation graph should be run. See
|
||||
@ -102,59 +119,58 @@ Here `x` and `y_` aren't specific values. Rather, they are each a `placeholder`
|
||||
-- a value that we'll input when we ask TensorFlow to run a computation.
|
||||
|
||||
The input images `x` will consist of a 2d tensor of floating point numbers.
|
||||
Here we assign it a `shape` of `[None, 784]`, where `784` is the dimensionality of
|
||||
a single flattened MNIST image, and `None` indicates that the first dimension,
|
||||
corresponding to the batch size, can be of any size.
|
||||
The target output classes `y_` will also consist of a 2d tensor,
|
||||
where each row is a one-hot 10-dimensional vector indicating
|
||||
which digit class the corresponding MNIST image belongs to.
|
||||
Here we assign it a `shape` of `[None, 784]`, where `784` is the dimensionality
|
||||
of a single flattened 28 by 28 pixel MNIST image, and `None` indicates that the
|
||||
first dimension, corresponding to the batch size, can be of any size. The
|
||||
target output classes `y_` will also consist of a 2d tensor, where each row is a
|
||||
one-hot 10-dimensional vector indicating which digit class (zero through nine)
|
||||
the corresponding MNIST image belongs to.
|
||||
|
||||
The `shape` argument to `placeholder` is optional, but it allows TensorFlow
|
||||
to automatically catch bugs stemming from inconsistent tensor shapes.
|
||||
|
||||
### Variables
|
||||
|
||||
We now define the weights `W` and biases `b` for our model. We could imagine treating
|
||||
these like additional inputs, but TensorFlow has an even better way to handle
|
||||
them: `Variable`.
|
||||
A `Variable` is a value that lives in TensorFlow's computation graph.
|
||||
It can be used and even modified by the computation. In machine
|
||||
learning applications, one generally has the model parameters be `Variable`s.
|
||||
We now define the weights `W` and biases `b` for our model. We could imagine
|
||||
treating these like additional inputs, but TensorFlow has an even better way to
|
||||
handle them: `Variable`. A `Variable` is a value that lives in TensorFlow's
|
||||
computation graph. It can be used and even modified by the computation. In
|
||||
machine learning applications, one generally has the model parameters be
|
||||
`Variable`s.
|
||||
|
||||
```python
|
||||
W = tf.Variable(tf.zeros([784,10]))
|
||||
b = tf.Variable(tf.zeros([10]))
|
||||
```
|
||||
|
||||
We pass the initial value for each parameter in the call to `tf.Variable`.
|
||||
In this case, we initialize both `W` and `b` as tensors full of
|
||||
zeros. `W` is a 784x10 matrix (because we have 784 input features
|
||||
and 10 outputs) and `b` is a 10-dimensional vector (because we have 10 classes).
|
||||
We pass the initial value for each parameter in the call to `tf.Variable`. In
|
||||
this case, we initialize both `W` and `b` as tensors full of zeros. `W` is a
|
||||
784x10 matrix (because we have 784 input features and 10 outputs) and `b` is a
|
||||
10-dimensional vector (because we have 10 classes).
|
||||
|
||||
Before `Variable`s can be used within a session, they must be initialized using
|
||||
that session.
|
||||
This step takes the initial values (in this case tensors full of zeros) that
|
||||
have already been specified, and assigns them to each `Variable`. This can be
|
||||
done for all `Variables` at once.
|
||||
that session. This step takes the initial values (in this case tensors full of
|
||||
zeros) that have already been specified, and assigns them to each
|
||||
`Variable`. This can be done for all `Variable`s at once:
|
||||
|
||||
```python
|
||||
sess.run(tf.initialize_all_variables())
|
||||
```
|
||||
|
||||
### Predicted Class and Cost Function
|
||||
### Predicted Class and Loss Function
|
||||
|
||||
We can now implement our regression model. It only takes one line!
|
||||
We multiply the vectorized input images `x` by the weight matrix `W`, add
|
||||
the bias `b`, and compute the softmax probabilities that are assigned to each
|
||||
class.
|
||||
We can now implement our regression model. It only takes one line! We multiply
|
||||
the vectorized input images `x` by the weight matrix `W`, add the bias `b`, and
|
||||
compute the softmax probabilities that are assigned to each class.
|
||||
|
||||
```python
|
||||
y = tf.nn.softmax(tf.matmul(x,W) + b)
|
||||
```
|
||||
|
||||
The cost function to be minimized during training can be specified just as
|
||||
easily. Our cost function will be the cross-entropy between the target and the
|
||||
model's prediction.
|
||||
We can specify a loss function just as easily. Loss indicates how bad the
|
||||
model's prediction was on a single example; we try to minimize that while
|
||||
training across all the examples. Here, our loss function is the cross-entropy
|
||||
between the target and the model's prediction:
|
||||
|
||||
```python
|
||||
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
|
||||
@ -165,16 +181,14 @@ the average over these sums.
|
||||
|
||||
## Train the Model
|
||||
|
||||
Now that we have defined our model and training cost function, it is
|
||||
straightforward to train using TensorFlow.
|
||||
Because TensorFlow knows the entire computation graph, it
|
||||
can use automatic differentiation to find the gradients of the cost with
|
||||
respect to each of the variables.
|
||||
TensorFlow has a variety of
|
||||
[builtin optimization algorithms]
|
||||
(../../../api_docs/python/train.md#optimizers).
|
||||
For this example, we will use steepest gradient descent, with a step length of
|
||||
0.5, to descend the cross entropy.
|
||||
Now that we have defined our model and training loss function, it is
|
||||
straightforward to train using TensorFlow. Because TensorFlow knows the entire
|
||||
computation graph, it can use automatic differentiation to find the gradients of
|
||||
the loss with respect to each of the variables. TensorFlow has a variety of
|
||||
[built-in optimization algorithms]
|
||||
(../../../api_docs/python/train.md#optimizers). For this example, we will use
|
||||
steepest gradient descent, with a step length of 0.5, to descend the cross
|
||||
entropy.
|
||||
|
||||
```python
|
||||
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
|
||||
@ -184,9 +198,9 @@ What TensorFlow actually did in that single line was to add new operations to
|
||||
the computation graph. These operations included ones to compute gradients,
|
||||
compute parameter update steps, and apply update steps to the parameters.
|
||||
|
||||
The returned operation `train_step`, when run, will apply the gradient
|
||||
descent updates to the parameters. Training the model can therefore be
|
||||
accomplished by repeatedly running `train_step`.
|
||||
The returned operation `train_step`, when run, will apply the gradient descent
|
||||
updates to the parameters. Training the model can therefore be accomplished by
|
||||
repeatedly running `train_step`.
|
||||
|
||||
```python
|
||||
for i in range(1000):
|
||||
@ -194,22 +208,21 @@ for i in range(1000):
|
||||
train_step.run(feed_dict={x: batch[0], y_: batch[1]})
|
||||
```
|
||||
|
||||
Each training iteration we load 100 training examples. We then run the
|
||||
We load 100 training examples in each training iteration. We then run the
|
||||
`train_step` operation, using `feed_dict` to replace the `placeholder` tensors
|
||||
`x` and `y_` with the training examples.
|
||||
Note that you can replace any tensor in your computation graph using `feed_dict`
|
||||
-- it's not restricted to just `placeholder`s.
|
||||
`x` and `y_` with the training examples. Note that you can replace any tensor
|
||||
in your computation graph using `feed_dict` -- it's not restricted to just
|
||||
`placeholder`s.
|
||||
|
||||
### Evaluate the Model
|
||||
|
||||
How well did our model do?
|
||||
|
||||
First we'll figure out where we predicted the correct label. `tf.argmax`
|
||||
is an extremely useful function which gives you the index of the highest entry
|
||||
in a tensor along some axis. For example, `tf.argmax(y,1)` is the label our
|
||||
model thinks is most likely for each input, while `tf.argmax(y_,1)` is the
|
||||
true label. We can use `tf.equal` to check if our prediction matches the
|
||||
truth.
|
||||
First we'll figure out where we predicted the correct label. `tf.argmax` is an
|
||||
extremely useful function which gives you the index of the highest entry in a
|
||||
tensor along some axis. For example, `tf.argmax(y,1)` is the label our model
|
||||
thinks is most likely for each input, while `tf.argmax(y_,1)` is the true
|
||||
label. We can use `tf.equal` to check if our prediction matches the truth.
|
||||
|
||||
```python
|
||||
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
|
||||
@ -241,10 +254,11 @@ to around 99.2% accuracy -- not state of the art, but respectable.
|
||||
|
||||
To create this model, we're going to need to create a lot of weights and biases.
|
||||
One should generally initialize weights with a small amount of noise for
|
||||
symmetry breaking, and to prevent 0 gradients. Since we're using ReLU neurons,
|
||||
it is also good practice to initialize them with a slightly positive initial
|
||||
bias to avoid "dead neurons". Instead of doing this repeatedly while we build
|
||||
the model, let's create two handy functions to do it for us.
|
||||
symmetry breaking, and to prevent 0 gradients. Since we're using
|
||||
[ReLU](https://en.wikipedia.org/wiki/Rectifier_(neural_networks)) neurons, it is
|
||||
also good practice to initialize them with a slightly positive initial bias to
|
||||
avoid "dead neurons". Instead of doing this repeatedly while we build the model,
|
||||
let's create two handy functions to do it for us.
|
||||
|
||||
```python
|
||||
def weight_variable(shape):
|
||||
@ -362,13 +376,21 @@ y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
|
||||
|
||||
### Train and Evaluate the Model
|
||||
|
||||
How well does this model do?
|
||||
To train and evaluate it we will use code that is nearly identical to that for
|
||||
the simple one layer SoftMax network above.
|
||||
The differences are that: we will replace the steepest gradient descent
|
||||
optimizer with the more sophisticated ADAM optimizer; we will include the
|
||||
additional parameter `keep_prob` in `feed_dict` to control the dropout rate;
|
||||
and we will add logging to every 100th iteration in the training process.
|
||||
How well does this model do? To train and evaluate it we will use code that is
|
||||
nearly identical to that for the simple one layer SoftMax network above.
|
||||
|
||||
The differences are that:
|
||||
|
||||
- We will replace the steepest gradient descent optimizer with the more
|
||||
sophisticated ADAM optimizer.
|
||||
|
||||
- We will include the additional parameter `keep_prob` in `feed_dict` to control
|
||||
the dropout rate.
|
||||
|
||||
- We will add logging to every 100th iteration in the training process.
|
||||
|
||||
Feel free to go ahead and run this code, but it does 20,000 training iterations
|
||||
and may take a while (possibly up to half an hour), depending on your processor.
|
||||
|
||||
```python
|
||||
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y_conv), reduction_indices=[1]))
|
||||
|
@ -58,9 +58,6 @@ Dataset | Purpose
|
||||
`data_sets.validation` | 5000 images and labels, for iterative validation of training accuracy.
|
||||
`data_sets.test` | 10000 images and labels, for final testing of trained accuracy.
|
||||
|
||||
For more information about the data, please read the [Download](../../../tutorials/mnist/download/index.md)
|
||||
tutorial.
|
||||
|
||||
### Inputs and Placeholders
|
||||
|
||||
The `placeholder_inputs()` function creates two [`tf.placeholder`](../../../api_docs/python/io_ops.md#placeholder)
|
||||
|
@ -1065,6 +1065,8 @@ tf_cuda_library(
|
||||
":construction_fails_op",
|
||||
":numpy_lib",
|
||||
":test_ops_kernels",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:direct_session",
|
||||
@ -1104,6 +1106,9 @@ tf_py_wrap_cc(
|
||||
":tf_session_helper",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:checkpoint_reader",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//util/python:python_headers",
|
||||
@ -1265,7 +1270,7 @@ cuda_py_tests(
|
||||
"training/server_lib_test.py",
|
||||
"training/session_manager_test.py",
|
||||
"training/supervisor_test.py",
|
||||
"training/saver_test.py",
|
||||
"training/saver_large_variable_test.py",
|
||||
],
|
||||
),
|
||||
additional_deps = [
|
||||
@ -1273,14 +1278,15 @@ cuda_py_tests(
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
name = "saver_test",
|
||||
py_test(
|
||||
name = "saver_large_variable_test",
|
||||
size = "small",
|
||||
srcs = ["training/saver_test.py"],
|
||||
additional_deps = [
|
||||
":training",
|
||||
],
|
||||
srcs = ["training/saver_large_variable_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
tags = ["notsan"], # http://b/30379628
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_py_test(
|
||||
@ -1311,6 +1317,7 @@ py_tests(
|
||||
["training/input_test.py"],
|
||||
),
|
||||
additional_deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
":training",
|
||||
],
|
||||
)
|
||||
|
@ -1217,8 +1217,7 @@ class InteractiveSession(BaseSession):
|
||||
|
||||
Args:
|
||||
target: (Optional.) The execution engine to connect to.
|
||||
Defaults to using an in-process engine. At present, no value
|
||||
other than the empty string is supported.
|
||||
Defaults to using an in-process engine.
|
||||
graph: (Optional.) The `Graph` to be launched (described above).
|
||||
config: (Optional) `ConfigProto` proto used to configure the session.
|
||||
"""
|
||||
|
@ -213,7 +213,7 @@ tensorflow::ImportNumpy();
|
||||
reinterpret_cast<const char*>($1.data), $1.length);
|
||||
}
|
||||
|
||||
// Include the functions from tensor_c_api.h, except TF_Run.
|
||||
// Include the functions from c_api.h, except TF_Run.
|
||||
%ignoreall
|
||||
%unignore TF_Code;
|
||||
%unignore TF_Status;
|
||||
@ -238,7 +238,7 @@ tensorflow::ImportNumpy();
|
||||
%unignore TF_NewLibrary;
|
||||
%unignore TF_LoadLibrary;
|
||||
%unignore TF_GetOpList;
|
||||
%include "tensorflow/core/public/tensor_c_api.h"
|
||||
%include "tensorflow/c/c_api.h"
|
||||
%ignoreall
|
||||
|
||||
%insert("python") %{
|
||||
|
@ -17,13 +17,13 @@ limitations under the License.
|
||||
|
||||
#include <cstring>
|
||||
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/log_memory.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/graph/equal_graph_def.h"
|
||||
#include "tensorflow/core/lib/core/coding.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/tf_status_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -19,11 +19,11 @@ limitations under the License.
|
||||
// Must be included first
|
||||
#include "tensorflow/python/lib/core/numpy.h"
|
||||
|
||||
#include "tensorflow/c/c_api.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/public/tensor_c_api.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
@ -721,6 +721,15 @@ class ControlFlowTest(tf.test.TestCase):
|
||||
c = tf.while_loop(lambda x: x < 10, lambda x: x + 1, [c])
|
||||
self.assertEqual(10, sess.run(c, {b: True}))
|
||||
|
||||
def testWhileWithControl_4(self):
|
||||
with self.test_session() as sess:
|
||||
b = tf.placeholder(tf.bool)
|
||||
c = tf.constant(1)
|
||||
x0 = tf.constant(0)
|
||||
with tf.control_dependencies([b]):
|
||||
r = tf.while_loop(lambda x: x < 10, lambda x: x + tf.identity(c), [x0])
|
||||
self.assertEqual(10, sess.run(r, {b: True}))
|
||||
|
||||
def testCondWhile_1(self):
|
||||
with self.test_session():
|
||||
n = tf.convert_to_tensor(0, name="n")
|
||||
|
@ -276,42 +276,55 @@ class RangeTest(tf.test.TestCase):
|
||||
# TODO(vrv): move to sequence_ops_test?
|
||||
class LinSpaceTest(tf.test.TestCase):
|
||||
|
||||
def _gpu_modes(self):
|
||||
if tf.test.is_gpu_available():
|
||||
return [False, True]
|
||||
else:
|
||||
return [False]
|
||||
|
||||
def _LinSpace(self, start, stop, num):
|
||||
with self.test_session():
|
||||
tf_ans = tf.linspace(start, stop, num, name="linspace")
|
||||
self.assertEqual([num], tf_ans.get_shape())
|
||||
return tf_ans.eval()
|
||||
# NOTE(touts): Needs to pass a graph to get a new session each time.
|
||||
with tf.Graph().as_default() as graph:
|
||||
with self.test_session(graph=graph, force_gpu=self.force_gpu):
|
||||
tf_ans = tf.linspace(start, stop, num, name="linspace")
|
||||
self.assertEqual([num], tf_ans.get_shape())
|
||||
return tf_ans.eval()
|
||||
|
||||
def testPositive(self):
|
||||
self.assertArrayNear(self._LinSpace(1., 5., 1), np.array([1.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(1., 5., 2), np.array([1., 5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(1., 5., 3),
|
||||
np.array([1., 3., 5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(1., 5., 4),
|
||||
np.array([1., 7. / 3., 11. / 3., 5.]), 1e-5)
|
||||
for self.force_gpu in self._gpu_modes():
|
||||
self.assertArrayNear(self._LinSpace(1., 5., 1), np.array([1.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(1., 5., 2), np.array([1., 5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(1., 5., 3),
|
||||
np.array([1., 3., 5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(1., 5., 4),
|
||||
np.array([1., 7. / 3., 11. / 3., 5.]), 1e-5)
|
||||
|
||||
def testNegative(self):
|
||||
self.assertArrayNear(self._LinSpace(-1., -5., 1), np.array([-1.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., -5., 2),
|
||||
np.array([-1., -5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., -5., 3),
|
||||
np.array([-1., -3., -5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., -5., 4),
|
||||
np.array([-1., -7. / 3., -11. / 3., -5.]), 1e-5)
|
||||
for self.force_gpu in self._gpu_modes():
|
||||
self.assertArrayNear(self._LinSpace(-1., -5., 1), np.array([-1.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., -5., 2),
|
||||
np.array([-1., -5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., -5., 3),
|
||||
np.array([-1., -3., -5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., -5., 4),
|
||||
np.array([-1., -7. / 3., -11. / 3., -5.]), 1e-5)
|
||||
|
||||
def testNegativeToPositive(self):
|
||||
self.assertArrayNear(self._LinSpace(-1., 5., 1), np.array([-1.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., 5., 2), np.array([-1., 5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., 5., 3),
|
||||
np.array([-1., 2., 5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., 5., 4),
|
||||
np.array([-1., 1., 3., 5.]), 1e-5)
|
||||
for self.force_gpu in self._gpu_modes():
|
||||
self.assertArrayNear(self._LinSpace(-1., 5., 1), np.array([-1.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., 5., 2), np.array([-1., 5.]),
|
||||
1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., 5., 3),
|
||||
np.array([-1., 2., 5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(-1., 5., 4),
|
||||
np.array([-1., 1., 3., 5.]), 1e-5)
|
||||
|
||||
def testPoint(self):
|
||||
self.assertArrayNear(self._LinSpace(5., 5., 1), np.array([5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(5., 5., 2), np.array([5.] * 2), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(5., 5., 3), np.array([5.] * 3), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(5., 5., 4), np.array([5.] * 4), 1e-5)
|
||||
for self.force_gpu in self._gpu_modes():
|
||||
self.assertArrayNear(self._LinSpace(5., 5., 1), np.array([5.]), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(5., 5., 2), np.array([5.] * 2), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(5., 5., 3), np.array([5.] * 3), 1e-5)
|
||||
self.assertArrayNear(self._LinSpace(5., 5., 4), np.array([5.] * 4), 1e-5)
|
||||
|
||||
|
||||
class DeviceTest(tf.test.TestCase):
|
||||
|
@ -233,27 +233,54 @@ class OneHotTest(tf.test.TestCase):
|
||||
dtype=dtype,
|
||||
truth=[truth[0].T, truth[1].T]) # Do not transpose the batch
|
||||
|
||||
def _testEmpty(self, dtype):
|
||||
indices = np.zeros((0, 16), dtype=np.int64)
|
||||
depth = 3
|
||||
on_value = np.asarray(1.0, dtype=dtype)
|
||||
off_value = np.asarray(-1.0, dtype=dtype)
|
||||
truth = np.empty((0, 16, 3), dtype=dtype)
|
||||
|
||||
# axis == -1
|
||||
self._testBothOneHot(
|
||||
indices=indices,
|
||||
depth=depth,
|
||||
on_value=on_value,
|
||||
off_value=off_value,
|
||||
dtype=dtype,
|
||||
truth=truth)
|
||||
|
||||
def testHalfBatch(self):
|
||||
self._testEmpty(np.float16)
|
||||
self._testBatch(np.float16)
|
||||
self._testDefaultValuesBatch(np.float16)
|
||||
self._testValueTypeBatch(np.float16)
|
||||
|
||||
def testFloatBatch(self):
|
||||
self._testEmpty(np.float32)
|
||||
self._testBatch(np.float32)
|
||||
self._testDefaultValuesBatch(np.float32)
|
||||
self._testValueTypeBatch(np.float32)
|
||||
|
||||
def testDoubleBatch(self):
|
||||
self._testEmpty(np.float64)
|
||||
self._testBatch(np.float64)
|
||||
self._testDefaultValuesBatch(np.float64)
|
||||
self._testValueTypeBatch(np.float64)
|
||||
|
||||
def testInt32Batch(self):
|
||||
self._testEmpty(np.int32)
|
||||
self._testBatch(np.int32)
|
||||
self._testDefaultValuesBatch(np.int32)
|
||||
self._testValueTypeBatch(np.int32)
|
||||
|
||||
def testInt64Batch(self):
|
||||
self._testEmpty(np.int64)
|
||||
self._testBatch(np.int64)
|
||||
self._testDefaultValuesBatch(np.int64)
|
||||
self._testValueTypeBatch(np.int64)
|
||||
|
||||
def testComplexBatch(self):
|
||||
self._testEmpty(np.complex64)
|
||||
self._testBatch(np.complex64)
|
||||
# self._testDefaultValuesBatch(np.complex64)
|
||||
self._testValueTypeBatch(np.complex64)
|
||||
|
@ -17,13 +17,13 @@ limitations under the License.
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/io/match.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/protobuf/meta_graph.pb.h"
|
||||
#include "tensorflow/core/util/tf_status_helper.h"
|
||||
%}
|
||||
|
||||
%{
|
||||
|
@ -118,6 +118,8 @@ def _Identity(data, name=None):
|
||||
else:
|
||||
return array_ops.identity(data, name=name)
|
||||
else:
|
||||
if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)):
|
||||
raise TypeError("Type %s not supported" % type(data))
|
||||
values = _Identity(data.values, name=name)
|
||||
indices = array_ops.identity(data.indices, name="indices")
|
||||
if isinstance(data, ops.IndexedSlices):
|
||||
@ -125,11 +127,9 @@ def _Identity(data, name=None):
|
||||
if dense_shape is not None:
|
||||
dense_shape = array_ops.identity(dense_shape, name="dense_shape")
|
||||
return ops.IndexedSlices(values, indices, dense_shape)
|
||||
elif isinstance(data, ops.SparseTensor):
|
||||
else:
|
||||
dense_shape = array_ops.identity(data.shape, name="dense_shape")
|
||||
return ops.SparseTensor(indices, values, dense_shape)
|
||||
else:
|
||||
raise TypeError("Type %s not supported" % type(data))
|
||||
|
||||
|
||||
def _NextIteration(data, name=None):
|
||||
@ -140,6 +140,8 @@ def _NextIteration(data, name=None):
|
||||
else:
|
||||
return next_iteration(data, name=name)
|
||||
else:
|
||||
if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)):
|
||||
raise TypeError("Type %s not supported" % type(data))
|
||||
values = _NextIteration(data.values, name=name)
|
||||
indices = next_iteration(data.indices, name="indices")
|
||||
if isinstance(data, ops.IndexedSlices):
|
||||
@ -147,11 +149,9 @@ def _NextIteration(data, name=None):
|
||||
if dense_shape is not None:
|
||||
dense_shape = next_iteration(dense_shape, name="dense_shape")
|
||||
return ops.IndexedSlices(values, indices, dense_shape)
|
||||
elif isinstance(data, ops.SparseTensor):
|
||||
else:
|
||||
dense_shape = next_iteration(data.shape, name="dense_shape")
|
||||
return ops.SparseTensor(indices, values, dense_shape)
|
||||
else:
|
||||
raise TypeError("Type %s not supported" % type(data))
|
||||
|
||||
|
||||
def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
|
||||
@ -183,6 +183,8 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
|
||||
return enter(data, frame_name, is_constant, parallel_iterations,
|
||||
name=name)
|
||||
else:
|
||||
if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)):
|
||||
raise TypeError("Type %s not supported" % type(data))
|
||||
values = _Enter(data.values, frame_name, is_constant,
|
||||
parallel_iterations, name=name)
|
||||
indices = enter(data.indices, frame_name, is_constant,
|
||||
@ -193,12 +195,10 @@ def _Enter(data, frame_name, is_constant=False, parallel_iterations=10,
|
||||
dense_shape = enter(dense_shape, frame_name, is_constant,
|
||||
parallel_iterations, name="dense_shape")
|
||||
return ops.IndexedSlices(values, indices, dense_shape)
|
||||
elif isinstance(data, ops.SparseTensor):
|
||||
else:
|
||||
dense_shape = enter(data.shape, frame_name, is_constant,
|
||||
parallel_iterations, name="dense_shape")
|
||||
return ops.SparseTensor(indices, values, dense_shape)
|
||||
else:
|
||||
raise TypeError("Type %s not supported" % type(data))
|
||||
|
||||
|
||||
def exit(data, name=None):
|
||||
@ -220,6 +220,8 @@ def exit(data, name=None):
|
||||
else:
|
||||
return gen_control_flow_ops._exit(data, name)
|
||||
else:
|
||||
if not isinstance(data, (ops.IndexedSlices, ops.SparseTensor)):
|
||||
raise TypeError("Type %s not supported" % type(data))
|
||||
values = exit(data.values, name=name)
|
||||
indices = gen_control_flow_ops._exit(data.indices, name="indices")
|
||||
if isinstance(data, ops.IndexedSlices):
|
||||
@ -227,11 +229,9 @@ def exit(data, name=None):
|
||||
if dense_shape is not None:
|
||||
dense_shape = gen_control_flow_ops._exit(dense_shape, name)
|
||||
return ops.IndexedSlices(values, indices, dense_shape)
|
||||
elif isinstance(data, ops.SparseTensor):
|
||||
else:
|
||||
dense_shape = gen_control_flow_ops._exit(data.shape, name)
|
||||
return ops.SparseTensor(indices, values, dense_shape)
|
||||
else:
|
||||
raise TypeError("Type %s not supported" % type(data))
|
||||
|
||||
|
||||
def switch(data, pred, dtype=None, name=None):
|
||||
@ -1493,8 +1493,15 @@ class WhileContext(ControlFlowContext):
|
||||
self._AddOpInternal(op)
|
||||
|
||||
def _AddOpInternal(self, op):
|
||||
"""Add `op` to the current context."""
|
||||
"""Add `op` to the current context.
|
||||
|
||||
In the case that op has only external data inputs, we remove all of its
|
||||
external control inputs so all its inputs are in the same while loop
|
||||
context. This is valid because op now has an Enter input that has all
|
||||
the right control dependency.
|
||||
"""
|
||||
if not op.inputs:
|
||||
# Remove any external control dependency on this op
|
||||
control_inputs = [x for x in op.control_inputs
|
||||
if x._get_control_flow_context() == self]
|
||||
if len(control_inputs) != len(op.control_inputs):
|
||||
@ -1508,12 +1515,22 @@ class WhileContext(ControlFlowContext):
|
||||
for x in op.outputs:
|
||||
self._values.add(x.name)
|
||||
else:
|
||||
has_internal_data_input = False
|
||||
for index in range(len(op.inputs)):
|
||||
x = op.inputs[index]
|
||||
self.AddValue(x)
|
||||
real_x = self._external_values.get(x.name)
|
||||
if real_x is not None:
|
||||
op._update_input(index, real_x)
|
||||
else:
|
||||
has_internal_data_input = True
|
||||
if not has_internal_data_input:
|
||||
# Remove any external control dependency on this op
|
||||
control_inputs = [x for x in op.control_inputs
|
||||
if x._get_control_flow_context() == self]
|
||||
if len(control_inputs) != len(op.control_inputs):
|
||||
del op.control_inputs[:]
|
||||
op._add_control_inputs(control_inputs)
|
||||
# Add a control dependency to prevent loop invariants from
|
||||
# enabling ops that should not be executed.
|
||||
self._MaybeAddControlDependency(op)
|
||||
@ -1879,6 +1896,8 @@ class WhileContext(ControlFlowContext):
|
||||
if isinstance(e, ops.Tensor):
|
||||
xs = [e]
|
||||
else:
|
||||
if not isinstance(e, (ops.IndexedSlices, ops.SparseTensor)):
|
||||
raise TypeError("Type %s not supported" % type(e))
|
||||
xs = [e.values, e.indices]
|
||||
shape = e.dense_shape if isinstance(e, ops.IndexedSlices) else e.shape
|
||||
if shape is not None:
|
||||
|
49
tensorflow/python/training/saver_large_variable_test.py
Normal file
49
tensorflow/python/training/saver_large_variable_test.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Tests for tensorflow.python.training.saver.py."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class SaverLargeVariableTest(tf.test.TestCase):
|
||||
|
||||
# NOTE: This is in a separate file from saver_test.py because the
|
||||
# large allocations do not play well with TSAN, and cause flaky
|
||||
# failures.
|
||||
def testLargeVariable(self):
|
||||
save_path = os.path.join(self.get_temp_dir(), "large_variable")
|
||||
with tf.Session("", graph=tf.Graph()) as sess:
|
||||
# Declare a variable that is exactly 2GB. This should fail,
|
||||
# because a serialized checkpoint includes other header
|
||||
# metadata.
|
||||
with tf.device("/cpu:0"):
|
||||
var = tf.Variable(
|
||||
tf.constant(False, shape=[2, 1024, 1024, 1024], dtype=tf.bool))
|
||||
save = tf.train.Saver({var.op.name: var})
|
||||
var.initializer.run()
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError,
|
||||
"Tensor slice is too large to serialize"):
|
||||
save.save(sess, save_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -287,22 +287,6 @@ class SaverTest(tf.test.TestCase):
|
||||
expected_save_path = "%s-%d" % (save_path, global_step_int)
|
||||
self.assertEqual(expected_save_path, val)
|
||||
|
||||
def testLargeVariable(self):
|
||||
save_path = os.path.join(self.get_temp_dir(), "large_variable")
|
||||
with tf.Session("", graph=tf.Graph()) as sess:
|
||||
# Declare a variable that is exactly 2GB. This should fail,
|
||||
# because a serialized checkpoint includes other header
|
||||
# metadata.
|
||||
with tf.device("/cpu:0"):
|
||||
var = tf.Variable(
|
||||
tf.constant(False, shape=[2, 1024, 1024, 1024], dtype=tf.bool))
|
||||
save = tf.train.Saver({var.op.name: var})
|
||||
var.initializer.run()
|
||||
with self.assertRaisesRegexp(
|
||||
tf.errors.InvalidArgumentError,
|
||||
"Tensor slice is too large to serialize"):
|
||||
save.save(sess, save_path)
|
||||
|
||||
|
||||
class SaveRestoreShardedTest(tf.test.TestCase):
|
||||
|
||||
|
@ -58,9 +58,9 @@ limitations under the License.
|
||||
}
|
||||
|
||||
%{
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/tf_status_helper.h"
|
||||
|
||||
using tensorflow::ServerDef;
|
||||
|
||||
|
@ -17,8 +17,8 @@ limitations under the License.
|
||||
%include "tensorflow/python/platform/base.i"
|
||||
|
||||
%{
|
||||
#include "tensorflow/c/checkpoint_reader.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/checkpoint_reader.h"
|
||||
#include "tensorflow/python/lib/core/py_func.h"
|
||||
%}
|
||||
|
||||
@ -126,5 +126,5 @@ def NewCheckpointReader(filepattern):
|
||||
return CheckpointReader(compat.as_bytes(filepattern), status)
|
||||
%}
|
||||
|
||||
%include "tensorflow/core/util/checkpoint_reader.h"
|
||||
%include "tensorflow/c/checkpoint_reader.h"
|
||||
%unignoreall
|
||||
|
@ -101,7 +101,7 @@ OS_SETUP="${TF_SRC_DIR}/g3doc/get_started/os_setup.md"
|
||||
check_existence file "${OS_SETUP}"
|
||||
|
||||
sed -i -r -e "s/(.*pip[0-9]* install .*tensorflow-)([0-9]+\.[0-9]+\.[[:alnum:]]+)(-.*\.whl)/\1${MAJOR}.${MINOR}.${PATCH}\3/g" "${OS_SETUP}"
|
||||
|
||||
sed -i -r -e "s/(.*export TF_BINARY_URL.*tensorflow-)([0-9]+\.[0-9]+\.[[:alnum:]]+)(-.*\.whl)/\1${MAJOR}.${MINOR}.${PATCH}\3/g" "${OS_SETUP}"
|
||||
sed -i -r -e "s/(.*\(e\.g\..*[^0-9])([0-9]+\.[0-9]+\.[[:alnum:]]+)(-gpu.*)/\1${MAJOR}.${MINOR}.${PATCH}\3/g" "${OS_SETUP}"
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user