Merge pull request #2509 from vrv/branch_123262728

Upstream changes for 5/25/2016
This commit is contained in:
Vijay Vasudevan 2016-05-25 17:29:11 -07:00
commit 3867aa29ea
72 changed files with 950 additions and 638 deletions

View File

@@ -1,6 +1,6 @@
 package(default_visibility = ["//visibility:public"])

-archive_dir = "eigen-eigen-a5e9085a94e8"
+archive_dir = "eigen-eigen-f3a13643ac1f"

 cc_library(
     name = "eigen",

View File

@ -7,7 +7,7 @@
include (ExternalProject) include (ExternalProject)
set(eigen_archive_hash "a5e9085a94e8") set(eigen_archive_hash "f3a13643ac1f")
set(eigen_INCLUDE_DIRS set(eigen_INCLUDE_DIRS
${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}
@ -16,7 +16,7 @@ set(eigen_INCLUDE_DIRS
${tensorflow_source_dir}/third_party/eigen3 ${tensorflow_source_dir}/third_party/eigen3
) )
set(eigen_URL https://bitbucket.org/eigen/eigen/get/${eigen_archive_hash}.tar.gz) set(eigen_URL https://bitbucket.org/eigen/eigen/get/${eigen_archive_hash}.tar.gz)
set(eigen_HASH SHA256=967126237829c7c87abb6cd0e13a5a235b0377d51575522c390b9486aed13e71) set(eigen_HASH SHA256=a9266e60366cddb371a23d86b11a297eee86372a89ef4b38a3509012f9cc37ec)
set(eigen_BUILD ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen) set(eigen_BUILD ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen)
set(eigen_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/eigen/install) set(eigen_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/eigen/install)

View File

@@ -55,9 +55,9 @@ cuda_py_tests(
 )

 cuda_py_tests(
-    name = "gaussian_test",
+    name = "normal_test",
     size = "small",
-    srcs = ["python/kernel_tests/gaussian_test.py"],
+    srcs = ["python/kernel_tests/normal_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:framework_test_lib",
@@ -98,9 +98,9 @@ cuda_py_tests(
 )

 cuda_py_tests(
-    name = "gaussian_conjugate_posteriors_test",
+    name = "normal_conjugate_posteriors_test",
     size = "small",
-    srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
+    srcs = ["python/kernel_tests/normal_conjugate_posteriors_test.py"],
     additional_deps = [
         ":distributions_py",
         "//tensorflow/python:platform_test",

View File

@@ -30,7 +30,7 @@ initialized with parameters that define the distributions.
 @@Chi2
 @@Exponential
 @@Gamma
-@@Gaussian
+@@Normal
 @@StudentT
 @@Uniform
@@ -44,10 +44,10 @@ initialized with parameters that define the distributions.
 Functions that transform conjugate prior/likelihood pairs to distributions
 representing the posterior or posterior predictive.

-### Gaussian likelihood with conjugate prior.
+### Normal likelihood with conjugate prior.

-@@gaussian_conjugates_known_sigma_posterior
-@@gaussian_congugates_known_sigma_predictive
+@@normal_conjugates_known_sigma_posterior
+@@normal_congugates_known_sigma_predictive
 """

 from __future__ import absolute_import
 from __future__ import division
@@ -60,8 +60,8 @@ from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
 from tensorflow.contrib.distributions.python.ops.distribution import *
 from tensorflow.contrib.distributions.python.ops.exponential import *
 from tensorflow.contrib.distributions.python.ops.gamma import *
-from tensorflow.contrib.distributions.python.ops.gaussian import *
-from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
 from tensorflow.contrib.distributions.python.ops.mvn import *
+from tensorflow.contrib.distributions.python.ops.normal import *
+from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
 from tensorflow.contrib.distributions.python.ops.student_t import *
 from tensorflow.contrib.distributions.python.ops.uniform import *
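
A minimal before/after sketch of this rename for callers, using only the `tf.contrib.distributions` re-exports shown above:

```python
import tensorflow as tf

# Before this commit:
#   dist = tf.contrib.distributions.Gaussian(mu=0., sigma=1.)
# After this commit, the same distribution is named Normal:
dist = tf.contrib.distributions.Normal(mu=0., sigma=1.)

with tf.Session():
  print(dist.pdf(0.5).eval())  # density at 0.5
  print(dist.cdf(0.5).eval())  # cumulative probability at 0.5
```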

View File

@@ -105,10 +105,9 @@ class ExponentialTest(tf.test.TestCase):
       exponential = tf.contrib.distributions.Exponential(lam=lam)

-      n_v = 100000
-      n = tf.constant(n_v)
+      n = 100000
       samples = exponential.sample(n, seed=138)
-      self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
+      self.assertEqual(samples.get_shape(), (n, batch_size, 2))

       sample_values = samples.eval()

View File

@@ -25,9 +25,9 @@ import tensorflow as tf

 distributions = tf.contrib.distributions

-class GaussianTest(tf.test.TestCase):
+class NormalTest(tf.test.TestCase):

-  def testGaussianConjugateKnownSigmaPosterior(self):
+  def testNormalConjugateKnownSigmaPosterior(self):
     with tf.Session():
       mu0 = tf.constant([3.0])
       sigma0 = tf.constant([math.sqrt(10.0)])
@@ -35,16 +35,16 @@ class GaussianTest(tf.test.TestCase):
       x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
       s = tf.reduce_sum(x)
       n = tf.size(x)
-      prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
-      posterior = distributions.gaussian_conjugates_known_sigma_posterior(
+      prior = distributions.Normal(mu=mu0, sigma=sigma0)
+      posterior = distributions.normal_conjugates_known_sigma_posterior(
           prior=prior, sigma=sigma, s=s, n=n)

       # Smoke test
-      self.assertTrue(isinstance(posterior, distributions.Gaussian))
+      self.assertTrue(isinstance(posterior, distributions.Normal))
       posterior_log_pdf = posterior.log_pdf(x).eval()
       self.assertEqual(posterior_log_pdf.shape, (6,))

-  def testGaussianConjugateKnownSigmaPosteriorND(self):
+  def testNormalConjugateKnownSigmaPosteriorND(self):
     with tf.Session():
       batch_size = 6
       mu0 = tf.constant([[3.0, -3.0]] * batch_size)
@@ -54,16 +54,16 @@ class GaussianTest(tf.test.TestCase):
           tf.constant([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=tf.float32))
       s = tf.reduce_sum(x)
       n = tf.size(x)
-      prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
-      posterior = distributions.gaussian_conjugates_known_sigma_posterior(
+      prior = distributions.Normal(mu=mu0, sigma=sigma0)
+      posterior = distributions.normal_conjugates_known_sigma_posterior(
           prior=prior, sigma=sigma, s=s, n=n)

       # Smoke test
-      self.assertTrue(isinstance(posterior, distributions.Gaussian))
+      self.assertTrue(isinstance(posterior, distributions.Normal))
       posterior_log_pdf = posterior.log_pdf(x).eval()
       self.assertEqual(posterior_log_pdf.shape, (6, 2))

-  def testGaussianConjugateKnownSigmaNDPosteriorND(self):
+  def testNormalConjugateKnownSigmaNDPosteriorND(self):
     with tf.Session():
       batch_size = 6
       mu0 = tf.constant([[3.0, -3.0]] * batch_size)
@@ -75,19 +75,19 @@ class GaussianTest(tf.test.TestCase):
       s = tf.reduce_sum(x, reduction_indices=[1])
       x = tf.transpose(x)  # Reshape to shape (6, 2)
       n = tf.constant([6] * 2)
-      prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
-      posterior = distributions.gaussian_conjugates_known_sigma_posterior(
+      prior = distributions.Normal(mu=mu0, sigma=sigma0)
+      posterior = distributions.normal_conjugates_known_sigma_posterior(
           prior=prior, sigma=sigma, s=s, n=n)

       # Smoke test
-      self.assertTrue(isinstance(posterior, distributions.Gaussian))
+      self.assertTrue(isinstance(posterior, distributions.Normal))

       # Calculate log_pdf under the 2 models
       posterior_log_pdf = posterior.log_pdf(x)
       self.assertEqual(posterior_log_pdf.get_shape(), (6, 2))
       self.assertEqual(posterior_log_pdf.eval().shape, (6, 2))

-  def testGaussianConjugateKnownSigmaPredictive(self):
+  def testNormalConjugateKnownSigmaPredictive(self):
     with tf.Session():
       batch_size = 6
       mu0 = tf.constant([3.0] * batch_size)
@@ -96,12 +96,12 @@ class GaussianTest(tf.test.TestCase):
       x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
       s = tf.reduce_sum(x)
       n = tf.size(x)
-      prior = distributions.Gaussian(mu=mu0, sigma=sigma0)
-      predictive = distributions.gaussian_congugates_known_sigma_predictive(
+      prior = distributions.Normal(mu=mu0, sigma=sigma0)
+      predictive = distributions.normal_congugates_known_sigma_predictive(
          prior=prior, sigma=sigma, s=s, n=n)

       # Smoke test
-      self.assertTrue(isinstance(predictive, distributions.Gaussian))
+      self.assertTrue(isinstance(predictive, distributions.Normal))
       predictive_log_pdf = predictive.log_pdf(x).eval()
       self.assertEqual(predictive_log_pdf.shape, (6,))

View File

@@ -24,9 +24,9 @@ import numpy as np
 import tensorflow as tf

-class GaussianTest(tf.test.TestCase):
+class NormalTest(tf.test.TestCase):

-  def testGaussianLogPDF(self):
+  def testNormalLogPDF(self):
     with tf.Session():
       batch_size = 6
       mu = tf.constant([3.0] * batch_size)
@@ -34,18 +34,18 @@ class GaussianTest(tf.test.TestCase):
       mu_v = 3.0
       sigma_v = np.sqrt(10.0)
       x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
-      gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
+      normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)

       expected_log_pdf = np.log(
           1 / np.sqrt(2 * np.pi) / sigma_v
           * np.exp(-1.0 / (2 * sigma_v**2) * (x - mu_v)**2))

-      log_pdf = gaussian.log_pdf(x)
+      log_pdf = normal.log_pdf(x)
       self.assertAllClose(expected_log_pdf, log_pdf.eval())

-      pdf = gaussian.pdf(x)
+      pdf = normal.pdf(x)
       self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())

-  def testGaussianLogPDFMultidimensional(self):
+  def testNormalLogPDFMultidimensional(self):
     with tf.Session():
       batch_size = 6
       mu = tf.constant([[3.0, -3.0]] * batch_size)
@@ -53,22 +53,22 @@ class GaussianTest(tf.test.TestCase):
       mu_v = np.array([3.0, -3.0])
       sigma_v = np.array([np.sqrt(10.0), np.sqrt(15.0)])
       x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
-      gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
+      normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)

       expected_log_pdf = np.log(
           1 / np.sqrt(2 * np.pi) / sigma_v
           * np.exp(-1.0 / (2 * sigma_v**2) * (x - mu_v)**2))

-      log_pdf = gaussian.log_pdf(x)
+      log_pdf = normal.log_pdf(x)
       log_pdf_values = log_pdf.eval()
       self.assertEqual(log_pdf.get_shape(), (6, 2))
       self.assertAllClose(expected_log_pdf, log_pdf_values)

-      pdf = gaussian.pdf(x)
+      pdf = normal.pdf(x)
       pdf_values = pdf.eval()
       self.assertEqual(pdf.get_shape(), (6, 2))
       self.assertAllClose(np.exp(expected_log_pdf), pdf_values)

-  def testGaussianCDF(self):
+  def testNormalCDF(self):
     with tf.Session():
       batch_size = 6
       mu = tf.constant([3.0] * batch_size)
@@ -77,40 +77,40 @@ class GaussianTest(tf.test.TestCase):
       sigma_v = np.sqrt(10.0)
       x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)

-      gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
+      normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
       erf_fn = np.vectorize(math.erf)

       # From Wikipedia
       expected_cdf = 0.5 * (1.0 + erf_fn((x - mu_v)/(sigma_v*np.sqrt(2))))

-      cdf = gaussian.cdf(x)
+      cdf = normal.cdf(x)
       self.assertAllClose(expected_cdf, cdf.eval())

-  def testGaussianEntropy(self):
+  def testNormalEntropy(self):
     with tf.Session():
       mu_v = np.array([1.0, 1.0, 1.0])
       sigma_v = np.array([[1.0, 2.0, 3.0]]).T
-      gaussian = tf.contrib.distributions.Gaussian(mu=mu_v, sigma=sigma_v)
+      normal = tf.contrib.distributions.Normal(mu=mu_v, sigma=sigma_v)

       sigma_broadcast = mu_v * sigma_v
       expected_entropy = 0.5 * np.log(2*np.pi*np.exp(1)*sigma_broadcast**2)
-      self.assertAllClose(expected_entropy, gaussian.entropy().eval())
+      self.assertAllClose(expected_entropy, normal.entropy().eval())

-  def testGaussianSample(self):
+  def testNormalSample(self):
     with tf.Session():
       mu = tf.constant(3.0)
       sigma = tf.constant(math.sqrt(10.0))
       mu_v = 3.0
       sigma_v = np.sqrt(10.0)
       n = tf.constant(100000)
-      gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
-      samples = gaussian.sample(n, seed=137)
+      normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
+      samples = normal.sample(n, seed=137)
       sample_values = samples.eval()
       self.assertEqual(sample_values.shape, (100000,))
       self.assertAllClose(sample_values.mean(), mu_v, atol=1e-2)
       self.assertAllClose(sample_values.std(), sigma_v, atol=1e-1)

-  def testGaussianSampleMultiDimensional(self):
+  def testNormalSampleMultiDimensional(self):
     with tf.Session():
       batch_size = 2
       mu = tf.constant([[3.0, -3.0]] * batch_size)
@@ -118,8 +118,8 @@ class GaussianTest(tf.test.TestCase):
       mu_v = [3.0, -3.0]
       sigma_v = [np.sqrt(10.0), np.sqrt(15.0)]
       n = tf.constant(100000)
-      gaussian = tf.contrib.distributions.Gaussian(mu=mu, sigma=sigma)
-      samples = gaussian.sample(n, seed=137)
+      normal = tf.contrib.distributions.Normal(mu=mu, sigma=sigma)
+      samples = normal.sample(n, seed=137)
       sample_values = samples.eval()
       self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
       self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=1e-2)
@@ -129,13 +129,13 @@ class GaussianTest(tf.test.TestCase):
   def testNegativeSigmaFails(self):
     with tf.Session():
-      gaussian = tf.contrib.distributions.Gaussian(
+      normal = tf.contrib.distributions.Normal(
           mu=[1.],
           sigma=[-5.],
           name='G')
       with self.assertRaisesOpError(
           r'should contain only positive values'):
-        gaussian.mean.eval()
+        normal.mean.eval()

 if __name__ == '__main__':
   tf.test.main()

View File

@@ -70,6 +70,7 @@ class Exponential(gamma.Gamma):
     """
     broadcast_shape = self._lam.get_shape()
     with ops.op_scope([self.lam, n], name, "ExponentialSample"):
+      n = ops.convert_to_tensor(n, name="n")
       shape = array_ops.concat(
           0, [array_ops.pack([n]), array_ops.shape(self._lam)])
       sampled = random_ops.random_uniform(
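
With the added `convert_to_tensor`, `sample` now accepts either a Python integer or an int32 `Tensor` for `n`, which is what the test change above exercises. A small sketch, assuming the contrib API of this era:

```python
import tensorflow as tf

exponential = tf.contrib.distributions.Exponential(lam=[1.0, 2.0])

with tf.Session():
  # Both forms work after this change; previously `n` had to be a Tensor.
  a = exponential.sample(100, seed=1)               # Python int
  b = exponential.sample(tf.constant(100), seed=1)  # int32 Tensor
  print(a.eval().shape, b.eval().shape)
```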

View File

@@ -38,8 +38,8 @@ def _assert_all_positive(x):
       ["Tensor %s should contain only positive values: " % x.name, x])

-class Gaussian(object):
-  """The scalar Gaussian distribution with mean and stddev parameters mu, sigma.
+class Normal(object):
+  """The scalar Normal distribution with mean and stddev parameters mu, sigma.

   #### Mathematical details
@@ -52,15 +52,15 @@ class Gaussian(object):
   Examples of initialization of one or a batch of distributions.

   ```python
-  # Define a single scalar Gaussian distribution.
-  dist = tf.contrib.distributions.Gaussian(mu=0, sigma=3)
+  # Define a single scalar Normal distribution.
+  dist = tf.contrib.distributions.Normal(mu=0, sigma=3)

   # Evaluate the cdf at 1, returning a scalar.
   dist.cdf(1)

-  # Define a batch of two scalar valued Gaussians.
+  # Define a batch of two scalar valued Normals.
   # The first has mean 1 and standard deviation 11, the second 2 and 22.
-  dist = tf.contrib.distributions.Gaussian(mu=[1, 2.], sigma=[11, 22.])
+  dist = tf.contrib.distributions.Normal(mu=[1, 2.], sigma=[11, 22.])

   # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
   # returning a length two tensor.
@@ -73,9 +73,9 @@ class Gaussian(object):
   Arguments are broadcast when possible.

   ```python
-  # Define a batch of two scalar valued Gaussians.
+  # Define a batch of two scalar valued Normals.
   # Both have mean 1, but different standard deviations.
-  dist = tf.contrib.distributions.Gaussian(mu=1, sigma=[11, 22.])
+  dist = tf.contrib.distributions.Normal(mu=1, sigma=[11, 22.])

   # Evaluate the pdf of both distributions on the same point, 3.0,
   # returning a length 2 tensor.
@@ -85,7 +85,7 @@ class Gaussian(object):
   """

   def __init__(self, mu, sigma, name=None):
-    """Construct Gaussian distributions with mean and stddev `mu` and `sigma`.
+    """Construct Normal distributions with mean and stddev `mu` and `sigma`.

     The parameters `mu` and `sigma` must be shaped in a way that supports
     broadcasting (e.g. `mu + sigma` is a valid operation).
@@ -99,7 +99,7 @@ class Gaussian(object):
     Raises:
       TypeError: if mu and sigma are different dtypes.
     """
-    with ops.op_scope([mu, sigma], name, "Gaussian"):
+    with ops.op_scope([mu, sigma], name, "Normal"):
       mu = ops.convert_to_tensor(mu)
       sigma = ops.convert_to_tensor(sigma)
       with ops.control_dependencies([_assert_all_positive(sigma)]):
@@ -125,7 +125,7 @@ class Gaussian(object):
     return self._mu * array_ops.ones_like(self._sigma)

   def log_pdf(self, x, name=None):
-    """Log pdf of observations in `x` under these Gaussian distribution(s).
+    """Log pdf of observations in `x` under these Normal distribution(s).

     Args:
       x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
@@ -134,7 +134,7 @@ class Gaussian(object):
     Returns:
       log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
     """
-    with ops.op_scope([self._mu, self._sigma, x], name, "GaussianLogPdf"):
+    with ops.op_scope([self._mu, self._sigma, x], name, "NormalLogPdf"):
       x = ops.convert_to_tensor(x)
       if x.dtype != self.dtype:
         raise TypeError("Input x dtype does not match dtype: %s vs. %s"
@@ -144,7 +144,7 @@ class Gaussian(object):
                       -0.5*math_ops.square((x - self._mu) / self._sigma))

   def cdf(self, x, name=None):
-    """CDF of observations in `x` under these Gaussian distribution(s).
+    """CDF of observations in `x` under these Normal distribution(s).

     Args:
       x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
@@ -153,7 +153,7 @@ class Gaussian(object):
     Returns:
       cdf: tensor of dtype `dtype`, the CDFs of `x`.
     """
-    with ops.op_scope([self._mu, self._sigma, x], name, "GaussianCdf"):
+    with ops.op_scope([self._mu, self._sigma, x], name, "NormalCdf"):
       x = ops.convert_to_tensor(x)
       if x.dtype != self.dtype:
         raise TypeError("Input x dtype does not match dtype: %s vs. %s"
@@ -162,7 +162,7 @@ class Gaussian(object):
           1.0/(math.sqrt(2.0) * self._sigma)*(x - self._mu)))

   def log_cdf(self, x, name=None):
-    """Log CDF of observations `x` under these Gaussian distribution(s).
+    """Log CDF of observations `x` under these Normal distribution(s).

     Args:
       x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
@@ -171,11 +171,11 @@ class Gaussian(object):
     Returns:
       log_cdf: tensor of dtype `dtype`, the log-CDFs of `x`.
     """
-    with ops.op_scope([self._mu, self._sigma, x], name, "GaussianLogCdf"):
+    with ops.op_scope([self._mu, self._sigma, x], name, "NormalLogCdf"):
       return math_ops.log(self.cdf(x))

   def pdf(self, x, name=None):
-    """The PDF of observations in `x` under these Gaussian distribution(s).
+    """The PDF of observations in `x` under these Normal distribution(s).

     Args:
       x: tensor of dtype `dtype`, must be broadcastable with `mu` and `sigma`.
@@ -184,11 +184,11 @@ class Gaussian(object):
     Returns:
       pdf: tensor of dtype `dtype`, the pdf values of `x`.
     """
-    with ops.op_scope([self._mu, self._sigma, x], name, "GaussianPdf"):
+    with ops.op_scope([self._mu, self._sigma, x], name, "NormalPdf"):
       return math_ops.exp(self.log_pdf(x))

   def entropy(self, name=None):
-    """The entropy of Gaussian distribution(s).
+    """The entropy of Normal distribution(s).

     Args:
       name: The name to give this op.
@@ -196,7 +196,7 @@ class Gaussian(object):
     Returns:
       entropy: tensor of dtype `dtype`, the entropy.
     """
-    with ops.op_scope([self._mu, self._sigma], name, "GaussianEntropy"):
+    with ops.op_scope([self._mu, self._sigma], name, "NormalEntropy"):
       two_pi_e1 = constant_op.constant(
           2 * math.pi * math.exp(1), dtype=self.dtype)
       # Use broadcasting rules to calculate the full broadcast sigma.
@@ -204,7 +204,7 @@ class Gaussian(object):
       return 0.5 * math_ops.log(two_pi_e1 * math_ops.square(sigma))

   def sample(self, n, seed=None, name=None):
-    """Sample `n` observations from the Gaussian Distributions.
+    """Sample `n` observations from the Normal Distributions.

     Args:
       n: `Scalar`, type int32, the number of observations to sample.
@@ -215,7 +215,7 @@ class Gaussian(object):
       samples: `[n, ...]`, a `Tensor` of `n` samples for each
         of the distributions determined by broadcasting the hyperparameters.
     """
-    with ops.op_scope([self._mu, self._sigma, n], name, "GaussianSample"):
+    with ops.op_scope([self._mu, self._sigma, n], name, "NormalSample"):
       broadcast_shape = (self._mu + self._sigma).get_shape()
       n = ops.convert_to_tensor(n)
       shape = array_ops.concat(
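
A quick usage sketch of the renamed class, mirroring the sampling test above (the `sample(n, seed, name)` signature is the one shown in this diff):

```python
import math
import tensorflow as tf

normal = tf.contrib.distributions.Normal(mu=3.0, sigma=math.sqrt(10.0))

with tf.Session():
  samples = normal.sample(100000, seed=137)
  values = samples.eval()
  # Sample statistics should approximate mu=3.0 and sigma=sqrt(10) ~ 3.16.
  print(values.mean(), values.std())
```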

View File

@@ -12,32 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================

-"""The Gaussian distribution: conjugate posterior closed form calculations."""
+"""The Normal distribution: conjugate posterior closed form calculations."""

 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tensorflow.contrib.distributions.python.ops.gaussian import Gaussian  # pylint: disable=line-too-long
+from tensorflow.contrib.distributions.python.ops.normal import Normal  # pylint: disable=line-too-long
 from tensorflow.python.ops import math_ops

-def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
-  """Posterior Gaussian distribution with conjugate prior on the mean.
+def normal_conjugates_known_sigma_posterior(prior, sigma, s, n):
+  """Posterior Normal distribution with conjugate prior on the mean.

   This model assumes that `n` observations (with sum `s`) come from a
-  Gaussian with unknown mean `mu` (described by the Gaussian `prior`)
+  Normal with unknown mean `mu` (described by the Normal `prior`)
   and known variance `sigma^2`.  The "known sigma posterior" is
   the distribution of the unknown `mu`.

-  Accepts a prior Gaussian distribution object, having parameters
+  Accepts a prior Normal distribution object, having parameters
   `mu0` and `sigma0`, as well as known `sigma` values of the predictive
-  distribution(s) (also assumed Gaussian),
+  distribution(s) (also assumed Normal),
   and statistical estimates `s` (the sum(s) of the observations) and
   `n` (the number(s) of observations).

-  Returns a posterior (also Gaussian) distribution object, with parameters
+  Returns a posterior (also Normal) distribution object, with parameters
   `(mu', sigma'^2)`, where:

   ```
@@ -50,7 +50,7 @@ def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
   will broadcast in the case of multidimensional sets of parameters.

   Args:
-    prior: `Gaussian` object of type `dtype`:
+    prior: `Normal` object of type `dtype`:
       the prior distribution having parameters `(mu0, sigma0)`.
     sigma: tensor of type `dtype`, taking values `sigma > 0`.
       The known stddev parameter(s).
@@ -58,15 +58,15 @@ def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
     n: Tensor of type `int`.  The number(s) of observations.

   Returns:
-    A new Gaussian posterior distribution object for the unknown observation
+    A new Normal posterior distribution object for the unknown observation
     mean `mu`.

   Raises:
     TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
-      Gaussian object.
+      Normal object.
   """
-  if not isinstance(prior, Gaussian):
-    raise TypeError("Expected prior to be an instance of type Gaussian")
+  if not isinstance(prior, Normal):
+    raise TypeError("Expected prior to be an instance of type Normal")

   if s.dtype != prior.dtype:
     raise TypeError(
@@ -77,27 +77,27 @@ def gaussian_conjugates_known_sigma_posterior(prior, sigma, s, n):
   sigma0_2 = math_ops.square(prior.sigma)
   sigma_2 = math_ops.square(sigma)
   sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2)
-  return Gaussian(
+  return Normal(
       mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2,
       sigma=math_ops.sqrt(sigmap_2))

-def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
-  """Posterior predictive Gaussian distribution w. conjugate prior on the mean.
+def normal_congugates_known_sigma_predictive(prior, sigma, s, n):
+  """Posterior predictive Normal distribution w. conjugate prior on the mean.

   This model assumes that `n` observations (with sum `s`) come from a
-  Gaussian with unknown mean `mu` (described by the Gaussian `prior`)
+  Normal with unknown mean `mu` (described by the Normal `prior`)
   and known variance `sigma^2`.  The "known sigma predictive"
   is the distribution of new observations, conditioned on the existing
   observations and our prior.

-  Accepts a prior Gaussian distribution object, having parameters
+  Accepts a prior Normal distribution object, having parameters
   `mu0` and `sigma0`, as well as known `sigma` values of the predictive
-  distribution(s) (also assumed Gaussian),
+  distribution(s) (also assumed Normal),
   and statistical estimates `s` (the sum(s) of the observations) and
   `n` (the number(s) of observations).

-  Calculates the Gaussian distribution(s) `p(x | sigma^2)`:
+  Calculates the Normal distribution(s) `p(x | sigma^2)`:

   ```
   p(x | sigma^2) = int N(x | mu, sigma^2) N(mu | prior.mu, prior.sigma^2) dmu
@@ -117,7 +117,7 @@ def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
   will broadcast in the case of multidimensional sets of parameters.

   Args:
-    prior: `Gaussian` object of type `dtype`:
+    prior: `Normal` object of type `dtype`:
       the prior distribution having parameters `(mu0, sigma0)`.
     sigma: tensor of type `dtype`, taking values `sigma > 0`.
       The known stddev parameter(s).
@@ -125,14 +125,14 @@ def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
     n: Tensor of type `int`.  The number(s) of observations.

   Returns:
-    A new Gaussian predictive distribution object.
+    A new Normal predictive distribution object.

   Raises:
     TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
-      Gaussian object.
+      Normal object.
   """
-  if not isinstance(prior, Gaussian):
-    raise TypeError("Expected prior to be an instance of type Gaussian")
+  if not isinstance(prior, Normal):
+    raise TypeError("Expected prior to be an instance of type Normal")

   if s.dtype != prior.dtype:
     raise TypeError(
@@ -143,6 +143,6 @@ def gaussian_congugates_known_sigma_predictive(prior, sigma, s, n):
   sigma0_2 = math_ops.square(prior.sigma)
   sigma_2 = math_ops.square(sigma)
   sigmap_2 = 1.0/(1/sigma0_2 + n/sigma_2)
-  return Gaussian(
+  return Normal(
       mu=(prior.mu/sigma0_2 + s/sigma_2) * sigmap_2,
       sigma=math_ops.sqrt(sigmap_2 + sigma_2))
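
For context, a minimal sketch of the renamed conjugacy helper in use, built only from the signatures shown in this diff (the concrete prior and sigma values here are illustrative):

```python
import tensorflow as tf

distributions = tf.contrib.distributions

# Known-sigma posterior over the unknown mean mu: prior N(mu0, sigma0^2),
# observations summarized by their sum `s` and count `n`.
x = tf.constant([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0])
prior = distributions.Normal(mu=[0.0], sigma=[1.0])
posterior = distributions.normal_conjugates_known_sigma_posterior(
    prior=prior, sigma=tf.constant([2.0]), s=tf.reduce_sum(x), n=tf.size(x))

with tf.Session():
  # The posterior is itself a Normal; evaluate its log-density at the data.
  print(posterior.log_pdf(x).eval())
```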

View File

@@ -17,6 +17,8 @@ filegroup(
     srcs = glob(["testdata/*"]),
 )

+exports_files(["ffmpeg_lib.h"])
+
 cc_library(
     name = "decode_audio_op_cc",
     srcs = ["decode_audio_op.cc"],

View File

@ -18,7 +18,7 @@
#include <cstdio> #include <cstdio>
#include <set> #include <set>
#include "tensorflow/contrib/ffmpeg/default/ffmpeg_lib.h" #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"

View File

@@ -11,7 +11,10 @@ package(default_visibility = ["//tensorflow:__subpackages__"])
 cc_library(
     name = "ffmpeg_lib",
     srcs = ["ffmpeg_lib.cc"],
-    hdrs = ["ffmpeg_lib.h"],
+    hdrs = [
+        # Header is shared between implementations.
+        "//tensorflow/contrib/ffmpeg:ffmpeg_lib.h",
+    ],
     deps = [
         "//google/protobuf",
         "//tensorflow/core:framework_headers_lib",

View File

@@ -13,7 +13,7 @@
 // limitations under the License.
 // =============================================================================

-#include "tensorflow/contrib/ffmpeg/default/ffmpeg_lib.h"
+#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"

 #include <errno.h>
 #include <stdlib.h>
@@ -212,9 +212,9 @@ Status ReadAudioFile(const string& filename,
   }
 }

-Status CreateAudioFile(const string& audio_format_id, int32 samples_per_second,
-                       int32 channel_count, const std::vector<float>& samples,
-                       string* output_data) {
+Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second,
+                       int32 samples_per_second, int32 channel_count,
+                       const std::vector<float>& samples, string* output_data) {
   if (audio_format_id != "wav") {
     return Status(error::Code::INVALID_ARGUMENT,
                   "CreateAudioFile only supports the 'wav' audio format.");

View File

@@ -13,7 +13,7 @@
 // limitations under the License.
 // =============================================================================

-#include "tensorflow/contrib/ffmpeg/default/ffmpeg_lib.h"
+#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"

 #include <stdlib.h>

 #include <vector>
@@ -91,7 +91,7 @@ TEST(FfmpegLibTest, TestRoundTripGeneratedWav) {
     sine_wave.push_back(std::sin(6.28 * 440.0 * i / 20000.0));
   }
   string content;
-  ASSERT_TRUE(CreateAudioFile("wav", 20000, 1, sine_wave, &content).ok());
+  ASSERT_TRUE(CreateAudioFile("wav", 0, 20000, 1, sine_wave, &content).ok());
   string temp_filename = GetTempFilename("wav");
   ASSERT_TRUE(WriteStringToFile(Env::Default(), temp_filename, content).ok());
   std::vector<float> roundtrip_data;
@@ -122,7 +122,7 @@ TEST(FfmpegLibTest, TestRoundTripWav) {
   string written_audio;
   ASSERT_TRUE(
-      CreateAudioFile("wav", 10000, 1, output_samples, &written_audio).ok());
+      CreateAudioFile("wav", 0, 10000, 1, output_samples, &written_audio).ok());
   EXPECT_EQ(original_audio, written_audio);
 }

View File

@ -15,7 +15,7 @@
#include <limits> #include <limits>
#include "tensorflow/contrib/ffmpeg/default/ffmpeg_lib.h" #include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
@ -35,6 +35,8 @@ class EncodeAudioOp : public OpKernel {
context, context->GetAttr("samples_per_second", &samples_per_second_)); context, context->GetAttr("samples_per_second", &samples_per_second_));
OP_REQUIRES(context, samples_per_second_ > 0, OP_REQUIRES(context, samples_per_second_ > 0,
errors::InvalidArgument("samples_per_second must be > 0.")); errors::InvalidArgument("samples_per_second must be > 0."));
OP_REQUIRES_OK(
context, context->GetAttr("bits_per_second", &bits_per_second_));
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
@ -61,9 +63,9 @@ class EncodeAudioOp : public OpKernel {
} }
const int32 channel_count = contents.dim_size(1); const int32 channel_count = contents.dim_size(1);
string encoded_audio; string encoded_audio;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context, CreateAudioFile(file_format_, bits_per_second_,
CreateAudioFile(file_format_, samples_per_second_, samples_per_second_, channel_count,
channel_count, samples, &encoded_audio)); samples, &encoded_audio));
// Copy the encoded audio file to the output tensor. // Copy the encoded audio file to the output tensor.
Tensor* output = nullptr; Tensor* output = nullptr;
@ -75,6 +77,7 @@ class EncodeAudioOp : public OpKernel {
private: private:
string file_format_; string file_format_;
int32 samples_per_second_; int32 samples_per_second_;
int32 bits_per_second_;
}; };
REGISTER_KERNEL_BUILDER(Name("EncodeAudio").Device(DEVICE_CPU), EncodeAudioOp); REGISTER_KERNEL_BUILDER(Name("EncodeAudio").Device(DEVICE_CPU), EncodeAudioOp);
@ -84,6 +87,7 @@ REGISTER_OP("EncodeAudio")
.Output("contents: string") .Output("contents: string")
.Attr("file_format: string") .Attr("file_format: string")
.Attr("samples_per_second: int") .Attr("samples_per_second: int")
.Attr("bits_per_second: int = 192000")
.Doc(R"doc( .Doc(R"doc(
Processes a `Tensor` containing sampled audio with the number of channels Processes a `Tensor` containing sampled audio with the number of channels
and length of the audio specified by the dimensions of the `Tensor`. The and length of the audio specified by the dimensions of the `Tensor`. The
@ -100,6 +104,8 @@ sampled_audio: A rank 2 tensor containing all tracks of the audio. Dimension 0
contents: The binary audio file contents. contents: The binary audio file contents.
file_format: A string describing the audio file format. This must be "wav". file_format: A string describing the audio file format. This must be "wav".
samples_per_second: The number of samples per second that the audio should have. samples_per_second: The number of samples per second that the audio should have.
bits_per_second: The approximate bitrate of the encoded audio file. This is
ignored by the "wav" file format.
)doc"); )doc");
} // namespace ffmpeg } // namespace ffmpeg
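
A user-facing sketch of the op this kernel implements. It is an assumption (not shown in this diff) that the `tf.contrib.ffmpeg` Python wrapper exposes the op as `encode_audio`; per the registration above, `bits_per_second` defaults to 192000 and is ignored for "wav":

```python
import math
import tensorflow as tf
from tensorflow.contrib import ffmpeg

# A 1-second, 440 Hz sine wave: shape [frames, channels].
samples = tf.constant(
    [[math.sin(6.28 * 440.0 * i / 20000.0)] for i in range(20000)],
    dtype=tf.float32)
wav_bytes = ffmpeg.encode_audio(
    samples, file_format='wav', samples_per_second=20000)

with tf.Session() as sess:
  print(len(sess.run(wav_bytes)))  # size of the encoded file in bytes
```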

View File

@ -13,10 +13,11 @@
// limitations under the License. // limitations under the License.
// ============================================================================= // =============================================================================
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_DEFAULT_FFMPEG_LIB_H_ #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_DEFAULT_FFMPEG_LIB_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_
#include <string> #include <string>
#include <vector>
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@ -40,9 +41,9 @@ Status ReadAudioFile(const string& filename,
// contain a separate sample for each channel. Frames are ordered by time. // contain a separate sample for each channel. Frames are ordered by time.
// Currently, the implementation only supports wav files, and ffmpeg is not used // Currently, the implementation only supports wav files, and ffmpeg is not used
// to create them. // to create them.
Status CreateAudioFile(const string& audio_format_id, int32 samples_per_second, Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second,
int32 channel_count, const std::vector<float>& samples, int32 samples_per_second, int32 channel_count,
string* output_data); const std::vector<float>& samples, string* output_data);
} // namespace ffmpeg } // namespace ffmpeg
} // namespace tensorflow } // namespace tensorflow

View File

@@ -39,6 +39,7 @@ from tensorflow.python.training import moving_averages
 # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API.
 __all__ = ['bias_add',
            'batch_norm',
+           'conv2d',
            'convolution2d',
            'fully_connected',
            'linear',
@@ -113,7 +114,7 @@ def batch_norm(inputs,
                scale=False,
                epsilon=0.001,
                activation_fn=None,
-               updates_collection=None,
+               updates_collections=ops.GraphKeys.UPDATE_OPS,
                is_training=True,
                reuse=None,
                variables_collections=None,
@@ -138,8 +139,9 @@ def batch_norm(inputs,
       disabled since the scaling can be done by the next layer.
     epsilon: small float added to variance to avoid dividing by zero.
     activation_fn: Optional activation function.
-    updates_collection: collection to collect the update ops for computation. If
-      None a control dependency would be added to make sure they are computed.
+    updates_collections: collections to collect the update ops for computation.
+      If None, a control dependency would be added to make sure the updates are
+      computed.
     is_training: whether or not the layer is in training mode. In training mode
       it would accumulate the statistics of the moments into `moving_mean` and
       `moving_variance` using an exponential moving average with the given
@@ -207,7 +209,7 @@ def batch_norm(inputs,
           moving_mean, mean, decay)
       update_moving_variance = moving_averages.assign_moving_average(
           moving_variance, variance, decay)
-      if updates_collection is None:
+      if updates_collections is None:
         # Make sure the updates are computed here.
         with ops.control_dependencies([update_moving_mean,
                                        update_moving_variance]):
@@ -215,8 +217,8 @@ def batch_norm(inputs,
               inputs, mean, variance, beta, gamma, epsilon)
       else:
         # Collect the updates to be computed later.
-        ops.add_to_collection(updates_collection, update_moving_mean)
-        ops.add_to_collection(updates_collection, update_moving_variance)
+        ops.add_to_collections(updates_collections, update_moving_mean)
+        ops.add_to_collections(updates_collections, update_moving_variance)
         outputs = nn.batch_normalization(
             inputs, mean, variance, beta, gamma, epsilon)
     else:
@@ -504,22 +506,6 @@ def legacy_fully_connected(x,
   Raises:
     ValueError: if x has rank less than 2 or if its last dimension is not set.
   """
-  # pylint: enable=anomalous-backslash-in-string
-  # TODO(ptucker) redirect to fully_connected
-  # _ = trainable
-  # variables_collections = {'weights': weight_collections,
-  #                          'biases': bias_collections}
-  # outputs = fully_connected(inputs=x,
-  #                           num_outputs=num_output_units,
-  #                           activation_fn=activation_fn,
-  #                           weights_initializer=weight_init,
-  #                           weights_regularizer=weight_regularizer,
-  #                           biases_initializer=bias_init,
-  #                           biases_regularizer=bias_regularizer,
-  #                           variables_collections=variables_collections,
-  #                           scope=name)
-  # ops.add_to_collections(output_collections, outputs)
-  # return outputs
   with variable_scope.variable_op_scope([x], name, 'fully_connected'):
     dims = x.get_shape().dims
     if dims is None:
@@ -645,24 +631,6 @@ def legacy_convolution2d(x,
   Raises:
     ValueError: If `kernel_size` or `stride` are not length 2.
   """
-  # TODO(ptucker) redirect to convolution2d
-  # _ = trainable
-  # variables_collections = {'weights': weight_collections,
-  #                          'biases': bias_collections}
-  # outputs = convolution2d(inputs=x,
-  #                         num_outputs=num_output_channels,
-  #                         kernel_size=kernel_size,
-  #                         stride=stride,
-  #                         padding=padding,
-  #                         activation_fn=activation_fn,
-  #                         weights_initializer=weight_init,
-  #                         weights_regularizer=weight_regularizer,
-  #                         biases_initializer=bias_init,
-  #                         biases_regularizer=bias_regularizer,
-  #                         variables_collections=variables_collections,
-  #                         scope=name)
-  # ops.add_to_collections(output_collections, outputs)
-  # return outputs
   with variable_scope.variable_op_scope([x], name, 'convolution2d'):
     num_input_channels = x.get_shape().dims[3].value
@@ -714,3 +682,6 @@ linear = legacy_linear
 relu = legacy_relu
 relu6 = legacy_relu6
+
+# Simple alias for convolution2d.
+conv2d = convolution2d
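
A sketch of the two update modes after this change, using the `batch_norm` signature shown above: by default the moving-average update ops land in `tf.GraphKeys.UPDATE_OPS` and must be run alongside the train step, while `updates_collections=None` forces the updates inline via a control dependency (the `decay` value here is illustrative):

```python
import tensorflow as tf

images = tf.random_uniform((5, 3, 3, 3), seed=1)

# Default: updates are collected; run them together with training.
output = tf.contrib.layers.batch_norm(images, decay=0.9)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
  train_ready_output = tf.identity(output)

# Alternative: force the moving-average updates in place. Simpler to use,
# at the cost of serializing the updates with the forward pass.
inline_output = tf.contrib.layers.batch_norm(
    images, decay=0.9, updates_collections=None)
```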

View File

@@ -430,8 +430,8 @@ class BatchNormTest(tf.test.TestCase):
     height, width = 3, 3
     with self.test_session():
       images = tf.random_uniform((5, height, width, 3), seed=1)
-      tf.contrib.layers.batch_norm(images, updates_collection='update_ops')
-      update_layers = tf.get_collection('update_ops')
+      tf.contrib.layers.batch_norm(images, updates_collections='my_update_ops')
+      update_layers = tf.get_collection('my_update_ops')
       update_moving_mean = update_layers[0]
       update_moving_variance = update_layers[1]
       self.assertEquals(update_moving_mean.op.name,
@@ -460,7 +460,7 @@ class BatchNormTest(tf.test.TestCase):
     with self.test_session():
       images = tf.random_uniform((5, height, width, 3), seed=1)
       with tf.contrib.framework.arg_scope([tf.contrib.layers.batch_norm],
-                                          updates_collection='update_ops'):
+                                          updates_collections='update_ops'):
         tf.contrib.layers.batch_norm(images, scope='bn')
       self.assertEquals(len(tf.get_collection('update_ops')), 2)
       tf.contrib.layers.batch_norm(images, scope='bn', reuse=True)
@@ -479,7 +479,7 @@ class BatchNormTest(tf.test.TestCase):
     self.assertEquals(len(moving_variance), 1)
     self.assertEquals(moving_variance[0].op.name, 'BatchNorm/moving_variance')

-  def testUpdateMovingVars(self):
+  def testForceUpdateMovingVars(self):
     height, width = 3, 3
     with self.test_session() as sess:
       image_shape = (10, height, width, 3)
@@ -487,7 +487,8 @@ class BatchNormTest(tf.test.TestCase):
       expected_mean = np.mean(image_values, axis=(0, 1, 2))
       expected_var = np.var(image_values, axis=(0, 1, 2))
       images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
-      output = tf.contrib.layers.batch_norm(images, decay=0.1)
+      output = tf.contrib.layers.batch_norm(images, decay=0.1,
+                                            updates_collections=None)
       # Initialize all variables
       sess.run(tf.initialize_all_variables())
       moving_mean = tf.contrib.framework.get_variables(
@@ -515,9 +516,8 @@ class BatchNormTest(tf.test.TestCase):
       expected_mean = np.mean(image_values, axis=(0, 1, 2))
       expected_var = np.var(image_values, axis=(0, 1, 2))
       images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
-      output = tf.contrib.layers.batch_norm(images, decay=0.1,
-                                            updates_collection='update_ops')
-      update_ops = tf.get_collection('update_ops')
+      output = tf.contrib.layers.batch_norm(images, decay=0.1)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
       with tf.control_dependencies(update_ops):
         barrier = tf.no_op(name='barrier')
       output = control_flow_ops.with_dependencies([barrier], output)
@@ -550,10 +550,9 @@ class BatchNormTest(tf.test.TestCase):
       images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
       output = tf.contrib.layers.batch_norm(images,
                                             decay=0.1,
-                                            is_training=False,
-                                            updates_collection='update_ops')
-      update_layers = tf.get_collection('update_ops')
-      self.assertEquals(update_layers, [])
+                                            is_training=False)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEquals(update_ops, [])
       # Initialize all variables
       sess.run(tf.initialize_all_variables())
       moving_mean = tf.contrib.framework.get_variables(
@@ -587,10 +586,9 @@ class BatchNormTest(tf.test.TestCase):
       images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
       output = tf.contrib.layers.batch_norm(images,
                                             decay=0.1,
-                                            is_training=False,
-                                            updates_collection='update_ops')
-      update_layers = tf.get_collection('update_ops')
-      self.assertEquals(update_layers, [])
+                                            is_training=False)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEquals(update_ops, [])
       # Initialize all variables
       sess.run(tf.initialize_all_variables())
       moving_mean = tf.contrib.framework.get_variables(

View File

@@ -1,5 +1,4 @@
-"""Main Scikit Flow module."""
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""High level API for learning with TensorFlow."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

View File

@@ -1,5 +1,4 @@
-"""Base utilities for loading datasets."""
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""Base utilities for loading datasets."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

View File

@@ -1,5 +1,4 @@
-"""Scikit Flow Estimators."""
-# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,12 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""Estimators."""
+
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 from tensorflow.contrib.learn.python.learn.estimators.autoencoder import TensorFlowDNNAutoencoder
-from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator, TensorFlowBaseTransformer
+from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
+from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowEstimator
 from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNClassifier
 from tensorflow.contrib.learn.python.learn.estimators.dnn import DNNRegressor
 from tensorflow.contrib.learn.python.learn.estimators.dnn import TensorFlowDNNClassifier

View File

@ -1,5 +1,4 @@
"""sklearn cross-support.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""sklearn cross-support."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
@ -20,6 +22,8 @@ import collections
import os import os
import numpy as np import numpy as np
import six
def _pprint(d): def _pprint(d):
return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()]) return ', '.join(['%s=%s' % (key, str(value)) for key, value in d.items()])
@ -102,6 +106,7 @@ class _BaseEstimator(object):
_pprint(self.get_params(deep=False)),) _pprint(self.get_params(deep=False)),)
# pylint: disable=old-style-class
class _ClassifierMixin(): class _ClassifierMixin():
"""Mixin class for all classifiers.""" """Mixin class for all classifiers."""
pass pass
@ -111,8 +116,10 @@ class _RegressorMixin():
"""Mixin class for all regression estimators.""" """Mixin class for all regression estimators."""
pass pass
class _TransformerMixin(): class _TransformerMixin():
"""Mixin class for all transformer estimators.""" """Mixin class for all transformer estimators."""
class _NotFittedError(ValueError, AttributeError): class _NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting. """Exception class to raise if estimator is used before fitting.
@ -134,6 +141,8 @@ class _NotFittedError(ValueError, AttributeError):
https://github.com/scikit-learn/scikit-learn/master/sklearn/exceptions.py https://github.com/scikit-learn/scikit-learn/master/sklearn/exceptions.py
""" """
# pylint: enable=old-style-class
def _accuracy_score(y_true, y_pred): def _accuracy_score(y_true, y_pred):
score = y_true == y_pred score = y_true == y_pred
@ -149,8 +158,7 @@ def _mean_squared_error(y_true, y_pred):
def _train_test_split(*args, **options): def _train_test_split(*args, **options):
n_array = len(args) # pylint: disable=missing-docstring
test_size = options.pop('test_size', None) test_size = options.pop('test_size', None)
train_size = options.pop('train_size', None) train_size = options.pop('train_size', None)
random_state = options.pop('random_state', None) random_state = options.pop('random_state', None)
@ -159,7 +167,7 @@ def _train_test_split(*args, **options):
train_size = 0.75 train_size = 0.75
elif train_size is None: elif train_size is None:
train_size = 1 - test_size train_size = 1 - test_size
train_size = train_size * args[0].shape[0] train_size *= args[0].shape[0]
np.random.seed(random_state) np.random.seed(random_state)
indices = np.random.permutation(args[0].shape[0]) indices = np.random.permutation(args[0].shape[0])
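For readers skimming the diff, the helper's overall flow survives the cleanup: default the train fraction, seed the generator, permute indices, slice. A standalone sketch of that flow (function name hypothetical, not the module's actual code):

```
import numpy as np

def train_test_split_sketch(X, y, test_size=None, train_size=None,
                            random_state=None):
  # Defaulting mirrors the code above: 0.75 when neither size is given,
  # otherwise the complement of test_size.
  if test_size is None and train_size is None:
    train_size = 0.75
  elif train_size is None:
    train_size = 1 - test_size
  n_train = int(train_size * X.shape[0])
  np.random.seed(random_state)
  indices = np.random.permutation(X.shape[0])
  return (X[indices[:n_train]], X[indices[n_train:]],
          y[indices[:n_train]], y[indices[n_train:]])
```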
@ -173,6 +181,7 @@ def _train_test_split(*args, **options):
# If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn. # If "TENSORFLOW_SKLEARN" flag is defined then try to import from sklearn.
TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False) TRY_IMPORT_SKLEARN = os.environ.get('TENSORFLOW_SKLEARN', False)
if TRY_IMPORT_SKLEARN: if TRY_IMPORT_SKLEARN:
# pylint: disable=g-import-not-at-top,g-multiple-import,unused-import
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, TransformerMixin
from sklearn.metrics import accuracy_score, log_loss, mean_squared_error from sklearn.metrics import accuracy_score, log_loss, mean_squared_error
from sklearn.cross_validation import train_test_split from sklearn.cross_validation import train_test_split

View File

@ -1,5 +1,4 @@
"""Deep Autoencoder estimators.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,105 +11,115 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Deep Autoencoder estimators."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.ops import nn import numpy as np
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.contrib.learn.python.learn import models from tensorflow.contrib.learn.python.learn import models
from tensorflow.contrib.learn.python.learn.estimators.base import TensorFlowBaseTransformer
from tensorflow.python.ops import nn
class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer): class TensorFlowDNNAutoencoder(TensorFlowBaseTransformer):
"""TensorFlow Autoencoder Regressor model. """TensorFlow Autoencoder Regressor model.
Parameters: Parameters:
hidden_units: List of hidden units per layer. hidden_units: List of hidden units per layer.
batch_size: Mini batch size. batch_size: Mini batch size.
activation: activation function used to map inner latent layer onto activation: activation function used to map inner latent layer onto
reconstruction layer. reconstruction layer.
add_noise: a function that adds noise to tensor_in, add_noise: a function that adds noise to tensor_in,
e.g. def add_noise(x): e.g. def add_noise(x):
return(x + np.random.normal(0, 0.1, (len(x), len(x[0])))) return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
steps: Number of steps to run over data. steps: Number of steps to run over data.
optimizer: Optimizer name (or class), for example "SGD", "Adam", optimizer: Optimizer name (or class), for example "SGD", "Adam",
"Adagrad". "Adagrad".
learning_rate: If this is a constant float value, no decay function is used. learning_rate: If this is a constant float value, no decay function is used.
Instead, a customized decay function can be passed that accepts Instead, a customized decay function can be passed that accepts
global_step as parameter and returns a Tensor. global_step as parameter and returns a Tensor.
e.g. exponential decay function: e.g. exponential decay function:
def exp_decay(global_step): def exp_decay(global_step):
return tf.train.exponential_decay( return tf.train.exponential_decay(
learning_rate=0.1, global_step, learning_rate=0.1, global_step,
decay_steps=2, decay_rate=0.001) decay_steps=2, decay_rate=0.001)
continue_training: when continue_training is True, once initialized continue_training: when continue_training is True, once initialized
model will be continually trained on every call of fit. model will be continually trained on every call of fit.
config: RunConfig object that controls the configurations of the session, config: RunConfig object that controls the configurations of the session,
e.g. num_cores, gpu_memory_fraction, etc. e.g. num_cores, gpu_memory_fraction, etc.
verbose: Controls the verbosity, possible values: verbose: Controls the verbosity, possible values:
0: the algorithm and debug information is muted. 0: the algorithm and debug information is muted.
1: trainer prints the progress. 1: trainer prints the progress.
2: log device placement is printed. 2: log device placement is printed.
dropout: When not None, the probability we will drop out a given dropout: When not None, the probability we will drop out a given
coordinate. coordinate.
""" """
def __init__(self, hidden_units, n_classes=0, batch_size=32,
steps=200, optimizer="Adagrad", learning_rate=0.1,
clip_gradients=5.0, activation=nn.relu, add_noise=None,
continue_training=False, config=None,
verbose=1, dropout=None):
self.hidden_units = hidden_units
self.dropout = dropout
self.activation = activation
self.add_noise = add_noise
super(TensorFlowDNNAutoencoder, self).__init__(
model_fn=self._model_fn,
n_classes=n_classes,
batch_size=batch_size, steps=steps, optimizer=optimizer,
learning_rate=learning_rate, clip_gradients=clip_gradients,
continue_training=continue_training,
config=config, verbose=verbose)
def _model_fn(self, X, y): def __init__(self, hidden_units, n_classes=0, batch_size=32,
encoder, decoder, autoencoder_estimator = models.get_autoencoder_model( steps=200, optimizer="Adagrad", learning_rate=0.1,
self.hidden_units, clip_gradients=5.0, activation=nn.relu, add_noise=None,
models.linear_regression, continue_training=False, config=None,
activation=self.activation, verbose=1, dropout=None):
add_noise=self.add_noise, self.hidden_units = hidden_units
dropout=self.dropout)(X) self.dropout = dropout
self.encoder = encoder self.activation = activation
self.decoder = decoder self.add_noise = add_noise
return autoencoder_estimator super(TensorFlowDNNAutoencoder, self).__init__(
model_fn=self._model_fn,
n_classes=n_classes,
batch_size=batch_size, steps=steps, optimizer=optimizer,
learning_rate=learning_rate, clip_gradients=clip_gradients,
continue_training=continue_training,
config=config, verbose=verbose)
def generate(self, hidden=None): def _model_fn(self, X, y):
"""Generate new data using trained construction layer""" encoder, decoder, autoencoder_estimator = models.get_autoencoder_model(
if hidden is None: self.hidden_units,
last_layer = len(self.hidden_units) - 1 models.linear_regression,
bias = self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % last_layer) activation=self.activation,
import numpy as np add_noise=self.add_noise,
hidden = np.random.normal(size=bias.shape) dropout=self.dropout)(X)
hidden = np.reshape(hidden, (1, len(hidden))) self.encoder = encoder
return self._session.run(self.decoder, feed_dict={self.encoder: hidden}) self.decoder = decoder
return autoencoder_estimator
@property def generate(self, hidden=None):
def weights_(self): """Generate new data using trained construction layer."""
"""Returns weights of the autoencoder's weight layers.""" if hidden is None:
weights = [] last_layer = len(self.hidden_units) - 1
for layer in range(len(self.hidden_units)): bias = self.get_tensor_value(
weights.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Matrix:0' % layer)) "encoder/dnn/layer%d/Linear/Bias:0" % last_layer)
for layer in range(len(self.hidden_units)): hidden = np.random.normal(size=bias.shape)
weights.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Matrix:0' % layer)) hidden = np.reshape(hidden, (1, len(hidden)))
weights.append(self.get_tensor_value('linear_regression/weights:0')) return self._session.run(self.decoder, feed_dict={self.encoder: hidden})
return weights
@property @property
def bias_(self): def weights_(self):
"""Returns bias of the autoencoder's bias layers.""" """Returns weights of the autoencoder's weight layers."""
biases = [] weights = []
for layer in range(len(self.hidden_units)): for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value('encoder/dnn/layer%d/Linear/Bias:0' % layer)) weights.append(self.get_tensor_value(
for layer in range(len(self.hidden_units)): "encoder/dnn/layer%d/Linear/Matrix:0" % layer))
biases.append(self.get_tensor_value('decoder/dnn/layer%d/Linear/Bias:0' % layer)) for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value('linear_regression/bias:0')) weights.append(self.get_tensor_value(
return biases "decoder/dnn/layer%d/Linear/Matrix:0" % layer))
weights.append(self.get_tensor_value("linear_regression/weights:0"))
return weights
@property
def bias_(self):
"""Returns bias of the autoencoder's bias layers."""
biases = []
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value(
"encoder/dnn/layer%d/Linear/Bias:0" % layer))
for layer in range(len(self.hidden_units)):
biases.append(self.get_tensor_value(
"decoder/dnn/layer%d/Linear/Bias:0" % layer))
biases.append(self.get_tensor_value("linear_regression/bias:0"))
return biases
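As a usage note for the reshuffled class above, a minimal driving sketch; the fit/transform signatures and the data are assumptions, and `add_noise` is the docstring's own illustration:

```
import numpy as np
from tensorflow.contrib.learn.python.learn.estimators.autoencoder import (
    TensorFlowDNNAutoencoder)

def add_noise(x):
  # Corrupt inputs with Gaussian noise, per the docstring's example.
  return x + np.random.normal(0, 0.1, (len(x), len(x[0])))

data = np.random.rand(100, 32).astype(np.float32)  # made-up inputs
autoencoder = TensorFlowDNNAutoencoder(hidden_units=[16, 8],
                                       add_noise=add_noise, steps=100)
autoencoder.fit(data)                 # unsupervised: reconstructs its input
latent = autoencoder.transform(data)  # inner-layer representation
```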

View File

@ -1,5 +1,4 @@
"""Base estimator class.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,18 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Base estimator class."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import datetime
import json import json
import os import os
import shutil
from six import string_types from six import string_types
import numpy as np
from google.protobuf import text_format from google.protobuf import text_format
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile

View File

@ -1,5 +1,4 @@
"""Deep Neural Network estimators.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Deep Neural Network estimators."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -563,9 +563,13 @@ class Estimator(BaseEstimator):
input_fn=input_fn, input_fn=input_fn,
batch_size=batch_size) batch_size=batch_size)
if self._classification: if self._classification:
for key in predictions: if isinstance(predictions, dict):
cur_axis = (len(predictions[key].shape) - 1) if axis is None else axis for key in predictions:
predictions[key] = np.argmax(predictions[key], axis=cur_axis) cur_axis = (len(predictions[key].shape) - 1) if axis is None else axis
predictions[key] = np.argmax(predictions[key], axis=cur_axis)
else:
cur_axis = (len(predictions.shape) - 1) if axis is None else axis
predictions = np.argmax(predictions, axis=cur_axis)
return predictions return predictions
def predict_proba(self, x=None, input_fn=None, batch_size=None): def predict_proba(self, x=None, input_fn=None, batch_size=None):
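The new branch makes `predict` work whether the model function returns a single array or a dict of named outputs. The same logic in isolation (helper name hypothetical):

```
import numpy as np

def argmax_predictions(predictions, axis=None):
  # Reduce over the last axis unless one is given; handle dict outputs.
  if isinstance(predictions, dict):
    return {key: np.argmax(val, axis=(val.ndim - 1) if axis is None else axis)
            for key, val in predictions.items()}
  return np.argmax(predictions,
                   axis=(predictions.ndim - 1) if axis is None else axis)

probs = np.array([[0.1, 0.9], [0.8, 0.2]])
print(argmax_predictions(probs))               # => [1 0]
print(argmax_predictions({'classes': probs}))  # => {'classes': array([1, 0])}
```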

View File

@ -36,6 +36,17 @@ def boston_input_fn():
return features, target return features, target
def iris_input_fn():
iris = tf.contrib.learn.datasets.load_iris()
features = tf.cast(
tf.reshape(
tf.constant(iris.data), [-1, 4]), tf.float32)
target = tf.cast(
tf.reshape(
tf.constant(iris.target), [-1, 1]), tf.int32)
return features, target
def boston_eval_fn(): def boston_eval_fn():
boston = tf.contrib.learn.datasets.load_boston() boston = tf.contrib.learn.datasets.load_boston()
n_examples = len(boston.target) n_examples = len(boston.target)
@ -52,6 +63,10 @@ def linear_model_fn(features, target, unused_mode):
return tf.contrib.learn.models.linear_regression_zero_init(features, target) return tf.contrib.learn.models.linear_regression_zero_init(features, target)
def logistic_model_fn(features, target, unused_mode):
return tf.contrib.learn.models.logistic_regression_zero_init(features, target)
class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor): class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor):
def __init__(self): def __init__(self):
@ -84,6 +99,15 @@ class EstimatorTest(tf.test.TestCase):
other_score = mean_squared_error(predictions, boston.target) other_score = mean_squared_error(predictions, boston.target)
self.assertAllClose(other_score, scores['mean_squared_error']) self.assertAllClose(other_score, scores['mean_squared_error'])
def testIrisAll(self):
iris = tf.contrib.learn.datasets.load_iris()
est = tf.contrib.learn.Estimator(model_fn=logistic_model_fn,
classification=True)
est.train(input_fn=iris_input_fn, steps=100)
_ = est.evaluate(input_fn=iris_input_fn, steps=1)
predictions = est.predict(x=iris.data)
self.assertEqual(predictions.shape[0], iris.target.shape[0])
def testTrainInputFn(self): def testTrainInputFn(self):
est = tf.contrib.learn.Estimator(model_fn=linear_model_fn, est = tf.contrib.learn.Estimator(model_fn=linear_model_fn,
classification=False) classification=False)

View File

@ -1,5 +1,4 @@
"""Linear Estimators.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Linear Estimators."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -1,5 +1,4 @@
"""Recurrent Neural Network estimators.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Recurrent Neural Network estimators."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -1,6 +1,4 @@
"""Implementations of different data feeders to provide data for TF trainer.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -14,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Implementations of different data feeders to provide data for TF trainer."""
# TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues. # TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues.
from __future__ import absolute_import from __future__ import absolute_import

View File

@ -1,5 +1,4 @@
"""Various high level TF models.""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,13 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Various high level TF models."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.contrib.learn.python.learn.ops import dnn_ops from tensorflow.contrib.learn.python.learn.ops import dnn_ops
from tensorflow.contrib.learn.python.learn.ops import losses_ops from tensorflow.contrib.learn.python.learn.ops import losses_ops
from tensorflow.contrib.learn.python.learn.ops import autoencoder_ops
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops as array_ops_ from tensorflow.python.ops import array_ops as array_ops_
@ -29,8 +31,7 @@ from tensorflow.python.ops import variable_scope as vs
def linear_regression_zero_init(X, y): def linear_regression_zero_init(X, y):
"""Creates a linear regression TensorFlow subgraph, in which weights and """Linear regression subgraph with zero-value initial weights and bias.
bias terms are initialized to exactly zero.
Args: Args:
X: tensor or placeholder for input features. X: tensor or placeholder for input features.
@ -43,8 +44,7 @@ def linear_regression_zero_init(X, y):
def logistic_regression_zero_init(X, y): def logistic_regression_zero_init(X, y):
"""Creates a logistic regression TensorFlow subgraph, in which weights and """Logistic regression subgraph with zero-value initial weights and bias.
bias terms are initialized to exactly zero.
Args: Args:
X: tensor or placeholder for input features. X: tensor or placeholder for input features.
@ -85,7 +85,7 @@ def linear_regression(X, y, init_mean=None, init_stddev=1.0):
else: else:
output_shape = y_shape[1] output_shape = y_shape[1]
# Set up the requested initialization. # Set up the requested initialization.
if (init_mean is None): if init_mean is None:
weights = vs.get_variable('weights', [X.get_shape()[1], output_shape]) weights = vs.get_variable('weights', [X.get_shape()[1], output_shape])
bias = vs.get_variable('bias', [output_shape]) bias = vs.get_variable('bias', [output_shape])
else: else:
@ -134,7 +134,7 @@ def logistic_regression(X,
logging_ops.histogram_summary('logistic_regression.X', X) logging_ops.histogram_summary('logistic_regression.X', X)
logging_ops.histogram_summary('logistic_regression.y', y) logging_ops.histogram_summary('logistic_regression.y', y)
# Set up the requested initialization. # Set up the requested initialization.
if (init_mean is None): if init_mean is None:
weights = vs.get_variable('weights', weights = vs.get_variable('weights',
[X.get_shape()[1], y.get_shape()[-1]]) [X.get_shape()[1], y.get_shape()[-1]])
bias = vs.get_variable('bias', [y.get_shape()[-1]]) bias = vs.get_variable('bias', [y.get_shape()[-1]])
@ -188,35 +188,37 @@ def get_dnn_model(hidden_units, target_predictor_fn, dropout=None):
return dnn_estimator return dnn_estimator
def get_autoencoder_model(hidden_units, target_predictor_fn, def get_autoencoder_model(hidden_units, target_predictor_fn,
activation, add_noise=None, dropout=None): activation, add_noise=None, dropout=None):
"""Returns a function that creates a Autoencoder TensorFlow subgraph with given """Returns a function that creates a Autoencoder TensorFlow subgraph.
params.
Args: Args:
hidden_units: List of values of hidden units for layers. hidden_units: List of values of hidden units for layers.
target_predictor_fn: Function that will predict target from input target_predictor_fn: Function that will predict target from input
features. This can be logistic regression, features. This can be logistic regression,
linear regression or any other model, linear regression or any other model,
that takes X, y and returns predictions and loss tensors. that takes X, y and returns predictions and loss
activation: activation function used to map inner latent layer onto tensors.
reconstruction layer. activation: activation function used to map inner latent layer onto
add_noise: a function that adds noise to tensor_in, reconstruction layer.
e.g. def add_noise(x): add_noise: a function that adds noise to tensor_in,
return(x + np.random.normal(0, 0.1, (len(x), len(x[0])))) e.g. def add_noise(x):
dropout: When not none, causes dropout regularization to be used, return(x + np.random.normal(0, 0.1, (len(x), len(x[0]))))
with the specified probability of removing a given coordinate. dropout: When not none, causes dropout regularization to be used,
with the specified probability of removing a given coordinate.
Returns:
A function that creates the subgraph.
"""
def dnn_autoencoder_estimator(X):
"""Autoencoder estimator with target predictor function on top."""
encoder, decoder = autoencoder_ops.dnn_autoencoder(
X, hidden_units, activation,
add_noise=add_noise, dropout=dropout)
return encoder, decoder, target_predictor_fn(X, decoder)
return dnn_autoencoder_estimator
Returns:
A function that creates the subgraph.
"""
def dnn_autoencoder_estimator(X):
"""Autoencoder estimator with target predictor function on top."""
encoder, decoder = autoencoder_ops.dnn_autoencoder(
X, hidden_units, activation,
add_noise=add_noise, dropout=dropout)
return encoder, decoder, target_predictor_fn(X, decoder)
return dnn_autoencoder_estimator
## This will be in Tensorflow 0.7. ## This will be in Tensorflow 0.7.
## TODO(ilblackdragon): Clean this up when it's released ## TODO(ilblackdragon): Clean this up when it's released

View File

@ -1,5 +1,4 @@
"""Monitors to track model training, report on progress and request early stopping""" # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -13,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Monitors to track training, report progress and request early stopping."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function

View File

@ -81,6 +81,9 @@ class GPUAllocatorRetryTest : public ::testing::Test {
return; return;
} }
} }
// Failures are more likely to occur if each consumer
// delays for a while before returning the memory.
Env::Default()->SleepForMicroseconds(500);
++consumer_count_[i]; ++consumer_count_[i];
for (int j = 0; j < cap_needed; ++j) { for (int j = 0; j < cap_needed; ++j) {
alloc_->DeallocateRaw(ptr); alloc_->DeallocateRaw(ptr);
@ -141,9 +144,10 @@ TEST_F(GPUAllocatorRetryTest, RetrySuccess) {
EXPECT_GT(consumer_count_[2], 0); EXPECT_GT(consumer_count_[2], 0);
} }
/* Disabled due to flakiness. b/24738751
// Verifies OutOfMemory failure when memory is slightly overcommitted // Verifies OutOfMemory failure when memory is slightly overcommitted
// and retry is not allowed. // and retry is not allowed. Note that this test will fail, i.e. no
// memory alloc failure will be detected, if it is run in a context that
// does not permit real multi-threaded execution.
TEST_F(GPUAllocatorRetryTest, NoRetryFail) { TEST_F(GPUAllocatorRetryTest, NoRetryFail) {
// Support up to 2 allocations simultaneously, waits up to 0 msec for // Support up to 2 allocations simultaneously, waits up to 0 msec for
// a chance to alloc. // a chance to alloc.
@ -162,7 +166,6 @@ TEST_F(GPUAllocatorRetryTest, NoRetryFail) {
EXPECT_TRUE(has_failed_); EXPECT_TRUE(has_failed_);
} }
} }
*/
// Verifies OutOfMemory failure when retry is allowed but memory capacity // Verifies OutOfMemory failure when retry is allowed but memory capacity
// is too low even for retry. // is too low even for retry.

View File

@ -32,7 +32,7 @@ static string GraphNodeName(const DotOptions& opts, const Node* n) {
return strings::StrCat("N", n->id()); return strings::StrCat("N", n->id());
} }
bool ShoulDisplayOpType(const Node* n) { bool ShouldDisplayOpType(const Node* n) {
if (n->type_string() == "NoOp") { if (n->type_string() == "NoOp") {
return false; return false;
} }
@ -125,7 +125,7 @@ string DotGraph(const Graph& g, const DotOptions& opts) {
continue; continue;
} }
string label = src->name(); string label = src->name();
if (ShoulDisplayOpType(src)) { if (ShouldDisplayOpType(src)) {
// Append the op type if it is not directly deducible from the op name. // Append the op type if it is not directly deducible from the op name.
strings::StrAppend(&label, "\\n(", src->type_string(), ")"); strings::StrAppend(&label, "\\n(", src->type_string(), ")");
} }
@ -137,7 +137,14 @@ string DotGraph(const Graph& g, const DotOptions& opts) {
shape = "oval"; shape = "oval";
} else { } else {
const string& d = src->assigned_device_name(); const string& d = src->assigned_device_name();
const int dindex = (!d.empty()) ? device_index[d] : -1;
int dindex;
if (opts.node_color) {
dindex = opts.node_color(src);
} else {
dindex = (!d.empty()) ? device_index[d] : -1;
}
if (dindex >= 0) { if (dindex >= 0) {
color = ColorFor(dindex); color = ColorFor(dindex);
} }

View File

@ -48,6 +48,11 @@ struct DotOptions {
// A function that returns the "cost" of the edge. The dot display // A function that returns the "cost" of the edge. The dot display
// makes an edge's thickness proportional to its cost. // makes an edge's thickness proportional to its cost.
std::function<double(const Edge*)> edge_cost; std::function<double(const Edge*)> edge_cost;
// A function that returns a color number to apply to each node. < 0 means
// no color. A color will be assigned to each color number from a palette;
// adjacent color numbers will receive different colors.
std::function<int(const Node*)> node_color;
}; };
// Return a string that contains a graphviz specification of the graph. // Return a string that contains a graphviz specification of the graph.

View File

@ -76,7 +76,7 @@ class ConcatOp : public OpKernel {
for (int d = 0; d < concat_dim; ++d) { for (int d = 0; d < concat_dim; ++d) {
inputs_flat_dim0 *= input_shape.dim_size(d); inputs_flat_dim0 *= input_shape.dim_size(d);
} }
int output_concat_dim = 0; int64 output_concat_dim = 0;
const bool input_is_scalar = IsLegacyScalar(input_shape); const bool input_is_scalar = IsLegacyScalar(input_shape);
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
const auto in = values[i]; const auto in = values[i];

View File

@ -61,6 +61,7 @@ class TensorContractionInputMapper<
typedef SubMapper LinearMapper; typedef SubMapper LinearMapper;
typedef typename packet_traits<Scalar>::type Packet; typedef typename packet_traits<Scalar>::type Packet;
EIGEN_DEVICE_FUNC
TensorContractionInputMapper( TensorContractionInputMapper(
const TensorEvaluator< const TensorEvaluator<
const TensorReshapingOp< const TensorReshapingOp<
@ -77,7 +78,7 @@ class TensorContractionInputMapper<
m_patch_cols = tensor.impl().dimensions()[2]; m_patch_cols = tensor.impl().dimensions()[2];
m_num_patches = tensor.impl().dimensions()[3]; m_num_patches = tensor.impl().dimensions()[3];
} else { } else {
static const int NumDims = tensor.impl().dimensions().size(); const int NumDims = tensor.impl().dimensions().size();
patch_depth = tensor.impl().dimensions()[NumDims - 1]; patch_depth = tensor.impl().dimensions()[NumDims - 1];
patch_rows = tensor.impl().dimensions()[NumDims - 2]; patch_rows = tensor.impl().dimensions()[NumDims - 2];
m_patch_cols = tensor.impl().dimensions()[NumDims - 3]; m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
@ -99,7 +100,7 @@ class TensorContractionInputMapper<
m_inputRows = tensor.impl().impl().dimensions()[1]; m_inputRows = tensor.impl().impl().dimensions()[1];
m_inputCols = tensor.impl().impl().dimensions()[2]; m_inputCols = tensor.impl().impl().dimensions()[2];
} else { } else {
static const int NumDims = tensor.impl().impl().dimensions().size(); const int NumDims = tensor.impl().impl().dimensions().size();
m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2]; m_inputRows = tensor.impl().impl().dimensions()[NumDims - 2];
m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3]; m_inputCols = tensor.impl().impl().dimensions()[NumDims - 3];
} }
@ -121,6 +122,7 @@ class TensorContractionInputMapper<
m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth); m_fastDimZero = internal::TensorIntDivisor<Index>(patch_depth);
} }
EIGEN_DEVICE_FUNC
TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper) TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
: m_impl(base_mapper.m_impl) { : m_impl(base_mapper.m_impl) {
m_patch_cols = base_mapper.m_patch_cols; m_patch_cols = base_mapper.m_patch_cols;
@ -650,8 +652,10 @@ struct gemm_pack_rhs<
SubMapper; SubMapper;
typedef SubMapper DataMapper; typedef SubMapper DataMapper;
EIGEN_DEVICE_FUNC
static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; } static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
Index depth, Index cols, Index stride = 0, Index depth, Index cols, Index stride = 0,
Index offset = 0) const { Index offset = 0) const {
@ -822,8 +826,10 @@ struct gemm_pack_rhs<
SubMapper; SubMapper;
typedef SubMapper DataMapper; typedef SubMapper DataMapper;
EIGEN_DEVICE_FUNC
static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; } static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
Index depth, Index cols, Index stride = 0, Index depth, Index cols, Index stride = 0,
Index offset = 0) const { Index offset = 0) const {
@ -898,36 +904,40 @@ struct gemm_pack_rhs<
* *
*/ */
template <typename Input, typename Kernel> template <typename Input, typename Kernel>
EIGEN_ALWAYS_INLINE static const typename internal::conditional< EIGEN_DEVICE_FUNC
internal::traits<Input>::Layout == ColMajor, EIGEN_ALWAYS_INLINE static const typename internal::conditional<
TensorReshapingOp< internal::traits<Input>::Layout == ColMajor,
const DSizes<typename internal::traits<Input>::Index, TensorReshapingOp<
internal::traits<Input>::NumDimensions>, const DSizes<typename internal::traits<Input>::Index,
const TensorContractionOp< internal::traits<Input>::NumDimensions>,
const array<IndexPair<typename internal::traits<Input>::Index>, 1>, const TensorContractionOp<
const TensorReshapingOp< const array<IndexPair<typename internal::traits<Input>::Index>,
const DSizes<typename internal::traits<Input>::Index, 2>, 1>,
const Kernel>, const TensorReshapingOp<
const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 2>,
const DSizes<typename internal::traits<Input>::Index, 2>, const Kernel>,
const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >, const TensorReshapingOp<
TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 2>,
const DSizes<typename internal::traits<Input>::Index, const TensorImagePatchOp<Dynamic, Dynamic,
internal::traits<Input>::NumDimensions>, const Input> > > >,
const TensorContractionOp< TensorReshapingOp<
const array<IndexPair<typename internal::traits<Input>::Index>, 1>, const DSizes<typename internal::traits<Input>::Index,
const TensorReshapingOp< internal::traits<Input>::NumDimensions>,
const DSizes<typename internal::traits<Input>::Index, 2>, const TensorContractionOp<
const TensorImagePatchOp<Dynamic, Dynamic, const Input> >, const array<IndexPair<typename internal::traits<Input>::Index>,
const TensorReshapingOp< 1>,
const DSizes<typename internal::traits<Input>::Index, 2>, const TensorReshapingOp<
const Kernel> > > >::type const DSizes<typename internal::traits<Input>::Index, 2>,
SpatialConvolution(const Input& input, const Kernel& kernel, const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
const DenseIndex row_stride = 1, const TensorReshapingOp<
const DenseIndex col_stride = 1, const DSizes<typename internal::traits<Input>::Index, 2>,
const PaddingType padding_type = PADDING_SAME, const Kernel> > > >::type
const DenseIndex row_in_stride = 1, SpatialConvolution(const Input& input, const Kernel& kernel,
const DenseIndex col_in_stride = 1) { const DenseIndex row_stride = 1,
const DenseIndex col_stride = 1,
const PaddingType padding_type = PADDING_SAME,
const DenseIndex row_in_stride = 1,
const DenseIndex col_in_stride = 1) {
typedef typename internal::traits<Input>::Index TensorIndex; typedef typename internal::traits<Input>::Index TensorIndex;
TensorRef<Tensor<typename internal::traits<Input>::Scalar, TensorRef<Tensor<typename internal::traits<Input>::Scalar,
internal::traits<Input>::NumDimensions, internal::traits<Input>::NumDimensions,
@ -941,9 +951,9 @@ SpatialConvolution(const Input& input, const Kernel& kernel,
EIGEN_STATIC_ASSERT( EIGEN_STATIC_ASSERT(
internal::traits<Input>::Layout == internal::traits<Kernel>::Layout, internal::traits<Input>::Layout == internal::traits<Kernel>::Layout,
YOU_MADE_A_PROGRAMMING_MISTAKE); YOU_MADE_A_PROGRAMMING_MISTAKE);
static const bool isColMajor = (internal::traits<Input>::Layout == ColMajor); const bool isColMajor = (internal::traits<Input>::Layout == ColMajor);
static const int NumDims = internal::traits<Input>::NumDimensions; const int NumDims = internal::traits<Input>::NumDimensions;
// Number of filters to apply. This is the same as the output depth of the // Number of filters to apply. This is the same as the output depth of the
// result // result

View File

@ -46,6 +46,19 @@ static inline void RandomShuffle(Iter first, Iter last, Random& uniform) {
} }
} }
template <class IntT, class InT, class OutT, class Random>
static void IndexedShuffle(const int64 size, const InT& input_mat,
OutT output_mat, Random& uniform) {
std::vector<IntT> permutation(size);
for (IntT i = 0; i < size; i++) {
permutation[i] = i;
}
RandomShuffle(permutation.begin(), permutation.end(), uniform);
for (IntT i = 0; i < size; i++) {
output_mat.template chip<0>(i) = input_mat.template chip<0>(permutation[i]);
}
}
template <typename T> template <typename T>
class RandomShuffleOp : public OpKernel { class RandomShuffleOp : public OpKernel {
public: public:
@ -79,14 +92,10 @@ class RandomShuffleOp : public OpKernel {
context->allocate_output(0, input.shape(), &output)); context->allocate_output(0, input.shape(), &output));
const auto input_mat = input.flat_outer_dims<T>(); const auto input_mat = input.flat_outer_dims<T>();
auto output_mat = output->flat_outer_dims<T>(); auto output_mat = output->flat_outer_dims<T>();
std::vector<int> permutation(size); if (size < kint32max) {
for (int i = 0; i < size; i++) { IndexedShuffle<int32>(size, input_mat, output_mat, uniform);
permutation[i] = i; } else {
} IndexedShuffle<int64>(size, input_mat, output_mat, uniform);
RandomShuffle(permutation.begin(), permutation.end(), uniform);
for (int i = 0; i < size; i++) {
output_mat.template chip<0>(i) =
input_mat.template chip<0>(permutation[i]);
} }
} }
} }
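The point of the refactor: the permutation's index type is now sized to the data, 32-bit when the outer dimension fits under `kint32max` and 64-bit otherwise, which halves the permutation's memory in the common case. The same idea in NumPy terms (a sketch, not the kernel):

```
import numpy as np

def indexed_shuffle(x, rng=None):
  # Build an identity permutation with the narrowest index type that
  # fits, shuffle it, then gather rows through it.
  rng = rng or np.random.RandomState()
  n = x.shape[0]
  index_dtype = np.int32 if n < np.iinfo(np.int32).max else np.int64
  perm = np.arange(n, dtype=index_dtype)
  rng.shuffle(perm)
  return x[perm]

print(indexed_shuffle(np.arange(12).reshape(4, 3)))
```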

View File

@ -74,6 +74,14 @@ class SparseReduceSumOp : public OpKernel {
std::vector<int32> axes(num_reduction_axes); std::vector<int32> axes(num_reduction_axes);
std::copy_n(reduction_axes_t->flat<int32>().data(), num_reduction_axes, std::copy_n(reduction_axes_t->flat<int32>().data(), num_reduction_axes,
axes.begin()); axes.begin());
for (int i = 0; i < num_reduction_axes; ++i) {
int32 axis = axes[i];
OP_REQUIRES(
ctx, axis >= -ndims && axis < ndims,
errors::InvalidArgument("Invalid reduction dimension ", axis,
", for input with ", ndims, " dimensions."));
axes[i] = (axes[i] + ndims) % ndims;
}
std::sort(axes.begin(), axes.end()); std::sort(axes.begin(), axes.end());
std::vector<int64> group_by_dims; std::vector<int64> group_by_dims;

View File

@ -430,7 +430,8 @@ Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless
with length 1. with length 1.
If `reduction_axes` has no entries, all dimensions are reduced, and a tensor If `reduction_axes` has no entries, all dimensions are reduced, and a tensor
with a single element is returned. with a single element is returned. Additionally, the axes can be negative;
in that case they are interpreted according to Python's indexing rules.
input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a
SparseTensor, possibly not in canonical ordering. SparseTensor, possibly not in canonical ordering.
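Tying the kernel change to the doc change: each axis is range-checked against [-ndims, ndims) and then normalized as `(axis + ndims) % ndims`, so -1 maps to the last dimension. A small sketch against the Python wrapper of this op (tensor values invented, API as of this release):

```
import tensorflow as tf

# 2x3 sparse matrix: [[1, 0, 2], [0, 3, 0]]
sp = tf.SparseTensor(indices=[[0, 0], [0, 2], [1, 1]],
                     values=[1, 2, 3], shape=[2, 3])
# reduction_axes=-1 now means the last axis: (-1 + 2) % 2 == 1.
row_sums = tf.sparse_reduce_sum(sp, reduction_axes=-1)
with tf.Session() as sess:
  print(sess.run(row_sums))  # => [3 3]
```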

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.

View File

@ -1,4 +1,4 @@
# Copyright 2015-present The Scikit Flow Authors. All Rights Reserved. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -12,147 +12,155 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """This example builds deep residual network for mnist data.
This example builds deep residual network for mnist data.
Reference Paper: http://arxiv.org/pdf/1512.03385.pdf Reference Paper: http://arxiv.org/pdf/1512.03385.pdf
Note that this is still a work-in-progress. Feel free to submit a PR Note that this is still a work-in-progress. Feel free to submit a PR
to make this better. to make this better.
""" """
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
from collections import namedtuple from collections import namedtuple
from math import sqrt from math import sqrt
import os
from sklearn import metrics from sklearn import metrics
import tensorflow as tf import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import learn from tensorflow.contrib import learn
from tensorflow.examples.tutorials.mnist import input_data
def res_net(x, y, activation=tf.nn.relu): def res_net(x, y, activation=tf.nn.relu):
"""Builds a residual network. Note that if the input tensor is 2D, it must be """Builds a residual network.
square in order to be converted to a 4D tensor.
Borrowed structure from here: https://github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py Note that if the input tensor is 2D, it must be square in order to be
converted to a 4D tensor.
Args: Borrowed structure from:
x: Input of the network github.com/pkmital/tensorflow_tutorials/blob/master/10_residual_network.py
y: Output of the network
activation: Activation function to apply after each convolution
"""
# Configurations for each bottleneck block Args:
BottleneckBlock = namedtuple( x: Input of the network
'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size']) y: Output of the network
blocks = [BottleneckBlock(3, 128, 32), activation: Activation function to apply after each convolution
BottleneckBlock(3, 256, 64),
BottleneckBlock(3, 512, 128),
BottleneckBlock(3, 1024, 256)]
input_shape = x.get_shape().as_list() Returns:
Predictions and loss tensors.
"""
# Reshape the input into the right shape if it's 2D tensor # Configurations for each bottleneck block.
if len(input_shape) == 2: BottleneckBlock = namedtuple(
ndim = int(sqrt(input_shape[1])) 'BottleneckBlock', ['num_layers', 'num_filters', 'bottleneck_size'])
x = tf.reshape(x, [-1, ndim, ndim, 1]) blocks = [BottleneckBlock(3, 128, 32),
BottleneckBlock(3, 256, 64),
BottleneckBlock(3, 512, 128),
BottleneckBlock(3, 1024, 256)]
# First convolution expands to 64 channels input_shape = x.get_shape().as_list()
with tf.variable_scope('conv_layer1'):
net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True,
activation=activation, bias=False)
# Max pool # Reshape the input into the right shape if it's 2D tensor
net = tf.nn.max_pool( if len(input_shape) == 2:
net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME') ndim = int(sqrt(input_shape[1]))
x = tf.reshape(x, [-1, ndim, ndim, 1])
# First chain of resnets # First convolution expands to 64 channels
with tf.variable_scope('conv_layer2'): with tf.variable_scope('conv_layer1'):
net = learn.ops.conv2d(net, blocks[0].num_filters, net = learn.ops.conv2d(x, 64, [7, 7], batch_norm=True,
[1, 1], [1, 1, 1, 1], activation=activation, bias=False)
padding='VALID', bias=True)
# Create each bottleneck building block for each layer # Max pool
for block_i, block in enumerate(blocks): net = tf.nn.max_pool(
for layer_i in range(block.num_layers): net, [1, 3, 3, 1], strides=[1, 2, 2, 1], padding='SAME')
name = 'block_%d/layer_%d' % (block_i, layer_i) # First chain of resnets
with tf.variable_scope('conv_layer2'):
net = learn.ops.conv2d(net, blocks[0].num_filters,
[1, 1], [1, 1, 1, 1],
padding='VALID', bias=True)
# 1x1 convolution responsible for reducing dimension # Create each bottleneck building block for each layer
with tf.variable_scope(name + '/conv_in'): for block_i, block in enumerate(blocks):
conv = learn.ops.conv2d(net, block.bottleneck_size, for layer_i in range(block.num_layers):
[1, 1], [1, 1, 1, 1],
padding='VALID',
activation=activation,
batch_norm=True,
bias=False)
with tf.variable_scope(name + '/conv_bottleneck'): name = 'block_%d/layer_%d' % (block_i, layer_i)
conv = learn.ops.conv2d(conv, block.bottleneck_size,
[3, 3], [1, 1, 1, 1],
padding='SAME',
activation=activation,
batch_norm=True,
bias=False)
# 1x1 convolution responsible for restoring dimension # 1x1 convolution responsible for reducing dimension
with tf.variable_scope(name + '/conv_out'): with tf.variable_scope(name + '/conv_in'):
conv = learn.ops.conv2d(conv, block.num_filters, conv = learn.ops.conv2d(net, block.bottleneck_size,
[1, 1], [1, 1, 1, 1], [1, 1], [1, 1, 1, 1],
padding='VALID', padding='VALID',
activation=activation, activation=activation,
batch_norm=True, batch_norm=True,
bias=False) bias=False)
# shortcut connections that turn the network into its counterpart with tf.variable_scope(name + '/conv_bottleneck'):
# residual function (identity shortcut) conv = learn.ops.conv2d(conv, block.bottleneck_size,
net = conv + net [3, 3], [1, 1, 1, 1],
padding='SAME',
activation=activation,
batch_norm=True,
bias=False)
try: # 1x1 convolution responsible for restoring dimension
# upscale to the next block size with tf.variable_scope(name + '/conv_out'):
next_block = blocks[block_i + 1] conv = learn.ops.conv2d(conv, block.num_filters,
with tf.variable_scope('block_%d/conv_upscale' % block_i): [1, 1], [1, 1, 1, 1],
net = learn.ops.conv2d(net, next_block.num_filters, padding='VALID',
[1, 1], [1, 1, 1, 1], activation=activation,
bias=False, batch_norm=True,
padding='SAME') bias=False)
except IndexError:
pass
net_shape = net.get_shape().as_list() # shortcut connections that turn the network into its counterpart
net = tf.nn.avg_pool(net, # residual function (identity shortcut)
ksize=[1, net_shape[1], net_shape[2], 1], net = conv + net
strides=[1, 1, 1, 1], padding='VALID')
net_shape = net.get_shape().as_list() try:
net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]]) # upscale to the next block size
next_block = blocks[block_i + 1]
with tf.variable_scope('block_%d/conv_upscale' % block_i):
net = learn.ops.conv2d(net, next_block.num_filters,
[1, 1], [1, 1, 1, 1],
bias=False,
padding='SAME')
except IndexError:
pass
return learn.models.logistic_regression(net, y) net_shape = net.get_shape().as_list()
net = tf.nn.avg_pool(net,
ksize=[1, net_shape[1], net_shape[2], 1],
strides=[1, 1, 1, 1], padding='VALID')
net_shape = net.get_shape().as_list()
net = tf.reshape(net, [-1, net_shape[1] * net_shape[2] * net_shape[3]])
return learn.models.logistic_regression(net, y)
# Download and load MNIST data. # Download and load MNIST data.
mnist = input_data.read_data_sets('MNIST_data') mnist = input_data.read_data_sets('MNIST_data')
# Restore model if graph is saved into a folder. # Restore model if graph is saved into a folder.
if os.path.exists("models/resnet/graph.pbtxt"): if os.path.exists('models/resnet/graph.pbtxt'):
classifier = learn.TensorFlowEstimator.restore("models/resnet/") classifier = learn.TensorFlowEstimator.restore('models/resnet/')
else: else:
# Create a new resnet classifier. # Create a new resnet classifier.
classifier = learn.TensorFlowEstimator( classifier = learn.TensorFlowEstimator(
model_fn=res_net, n_classes=10, batch_size=100, steps=100, model_fn=res_net, n_classes=10, batch_size=100, steps=100,
learning_rate=0.001, continue_training=True) learning_rate=0.001, continue_training=True)
while True: while True:
# Train model and save summaries into logdir. # Train model and save summaries into logdir.
classifier.fit(mnist.train.images, mnist.train.labels, logdir="models/resnet/") classifier.fit(
mnist.train.images, mnist.train.labels, logdir='models/resnet/')
# Calculate accuracy. # Calculate accuracy.
score = metrics.accuracy_score( score = metrics.accuracy_score(
mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64)) mnist.test.labels, classifier.predict(mnist.test.images, batch_size=64))
print('Accuracy: {0:f}'.format(score)) print('Accuracy: {0:f}'.format(score))
# Save model graph and checkpoints. # Save model graph and checkpoints.
classifier.save("models/resnet/") classifier.save('models/resnet/')

View File

@ -74,7 +74,7 @@ example_names: ["input0", "input1"],
features: { features: {
"kw": VarLenFeature(tf.string), "kw": VarLenFeature(tf.string),
"dank": VarLenFeature(tf.int64), "dank": VarLenFeature(tf.int64),
"gps": VarLenFeature(tf.float), "gps": VarLenFeature(tf.float32),
} }
``` ```
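This doc fix (here and in its twin below) swaps the nonexistent `tf.float` for the real dtype `tf.float32`. A runnable sketch of parsing such a feature (example data invented):

```
import tensorflow as tf

# One serialized Example carrying a variable-length float "gps" feature.
example = tf.train.Example(features=tf.train.Features(feature={
    "gps": tf.train.Feature(
        float_list=tf.train.FloatList(value=[37.77, -122.41])),
}))
serialized = tf.constant([example.SerializeToString()])
# VarLenFeature takes tf.float32; there is no tf.float dtype.
parsed = tf.parse_example(serialized,
                          features={"gps": tf.VarLenFeature(tf.float32)})
```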

View File

@ -1289,7 +1289,7 @@ example_names: ["input0", "input1"],
features: { features: {
"kw": VarLenFeature(tf.string), "kw": VarLenFeature(tf.string),
"dank": VarLenFeature(tf.int64), "dank": VarLenFeature(tf.int64),
"gps": VarLenFeature(tf.float), "gps": VarLenFeature(tf.float32),
} }
``` ```

View File

@ -1916,10 +1916,18 @@ class Graph(object):
def __init__(self): def __init__(self):
"""Creates a new, empty Graph.""" """Creates a new, empty Graph."""
self._nodes_by_id = dict() # Protects the core state that may be accessed by multiple readers.
self._next_node_id = [dict()] # Only state that can be returned via public accessors (`as_graph_def()`,
self._next_id_counter = 0 # `get_operations()`, `as_graph_element()`, `get_collection()`, and
self._nodes_by_name = dict() # `get_collection_ref()`) is guarded by the lock. Thread-safety is provided on a
# best-effort basis to support buggy programs, and is not guaranteed by the
# public `tf.Graph` API.
# NOTE(mrry): This does not protect the various stacks. A warning will
# be reported if these are used from multiple threads.
self._lock = threading.Lock()
self._nodes_by_id = dict() # GUARDED_BY(self._lock)
self._next_id_counter = 0 # GUARDED_BY(self._lock)
self._nodes_by_name = dict() # GUARDED_BY(self._lock)
# Current name stack: uniquified names # Current name stack: uniquified names
self._name_stack = "" self._name_stack = ""
# Maps a name used in the graph to the next id to use for that name. # Maps a name used in the graph to the next id to use for that name.
@ -1987,15 +1995,15 @@ class Graph(object):
self._check_not_finalized() self._check_not_finalized()
if not isinstance(op, (Tensor, Operation)): if not isinstance(op, (Tensor, Operation)):
raise TypeError("op must be a Tensor or Operation: %s" % op) raise TypeError("op must be a Tensor or Operation: %s" % op)
with self._lock:
if op._id in self._nodes_by_id: if op._id in self._nodes_by_id:
raise ValueError("cannot add an op with id %d as it already " raise ValueError("cannot add an op with id %d as it already "
"exists in the graph" % op._id) "exists in the graph" % op._id)
if op.name in self._nodes_by_name: if op.name in self._nodes_by_name:
raise ValueError("cannot add op with name %s as that name " raise ValueError("cannot add op with name %s as that name "
"is already used" % op.name) "is already used" % op.name)
self._nodes_by_id[op._id] = op self._nodes_by_id[op._id] = op
self._nodes_by_name[op.name] = op self._nodes_by_name[op.name] = op
@property @property
def version(self): def version(self):
@ -2081,31 +2089,32 @@ class Graph(object):
Raises: Raises:
ValueError: If the `graph_def` would be too large. ValueError: If the `graph_def` would be too large.
""" """
graph = graph_pb2.GraphDef() with self._lock:
graph.versions.CopyFrom(self._graph_def_versions) graph = graph_pb2.GraphDef()
bytesize = 0 graph.versions.CopyFrom(self._graph_def_versions)
for op_id in sorted(self._nodes_by_id): bytesize = 0
op = self._nodes_by_id[op_id] for op_id in sorted(self._nodes_by_id):
if from_version is None or op_id > from_version: op = self._nodes_by_id[op_id]
graph.node.extend([op.node_def]) if from_version is None or op_id > from_version:
if op.outputs and add_shapes: graph.node.extend([op.node_def])
assert "_output_shapes" not in graph.node[-1].attr if op.outputs and add_shapes:
graph.node[-1].attr["_output_shapes"].list.shape.extend([ assert "_output_shapes" not in graph.node[-1].attr
output.get_shape().as_proto() for output in op.outputs]) graph.node[-1].attr["_output_shapes"].list.shape.extend([
bytesize += op.node_def.ByteSize() output.get_shape().as_proto() for output in op.outputs])
if bytesize >= (1 << 31) or bytesize < 0: bytesize += op.node_def.ByteSize()
raise ValueError("GraphDef cannot be larger than 2GB.") if bytesize >= (1 << 31) or bytesize < 0:
if self._functions: raise ValueError("GraphDef cannot be larger than 2GB.")
for f in self._functions.values(): if self._functions:
bytesize += f.ByteSize() for f in self._functions.values():
if bytesize >= (1 << 31) or bytesize < 0: bytesize += f.ByteSize()
raise ValueError("GraphDef cannot be larger than 2GB.") if bytesize >= (1 << 31) or bytesize < 0:
graph.library.function.extend(self._functions.values()) raise ValueError("GraphDef cannot be larger than 2GB.")
for func in self._function_gradient: graph.library.function.extend(self._functions.values())
grad_def = function_pb2.GradientDef() for func in self._function_gradient:
grad_def.function_name = func grad_def = function_pb2.GradientDef()
grad_def.gradient_func = self._function_gradient[func] grad_def.function_name = func
graph.library.gradient.extend([grad_def]) grad_def.gradient_func = self._function_gradient[func]
graph.library.gradient.extend([grad_def])
return graph return graph
@ -2298,7 +2307,11 @@ class Graph(object):
example, an invalid string. example, an invalid string.
KeyError: If `obj` is not an object in the graph. KeyError: If `obj` is not an object in the graph.
""" """
with self._lock:
return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
"""See `Graph.as_graph_element()` for details."""
# The vast majority of this function is figuring # The vast majority of this function is figuring
# out what an API user might be doing wrong, so # out what an API user might be doing wrong, so
# that we can give helpful error messages. # that we can give helpful error messages.
@ -2398,7 +2411,8 @@ class Graph(object):
Returns: Returns:
A list of Operations. A list of Operations.
""" """
return list(self._nodes_by_id.values()) with self._lock:
return list(self._nodes_by_id.values())
def get_operation_by_name(self, name): def get_operation_by_name(self, name):
"""Returns the `Operation` with the given `name`. """Returns the `Operation` with the given `name`.
@ -2445,8 +2459,9 @@ class Graph(object):
def _next_id(self): def _next_id(self):
"""Id for next Operation instance. Also increments the internal id.""" """Id for next Operation instance. Also increments the internal id."""
self._check_not_finalized() self._check_not_finalized()
self._next_id_counter += 1 with self._lock:
return self._next_id_counter self._next_id_counter += 1
return self._next_id_counter
@property @property
def _last_id(self): def _last_id(self):
@ -2499,10 +2514,11 @@ class Graph(object):
value: The value to add to the collection. value: The value to add to the collection.
""" """
self._check_not_finalized() self._check_not_finalized()
if name not in self._collections: with self._lock:
self._collections[name] = [value] if name not in self._collections:
else: self._collections[name] = [value]
self._collections[name].append(value) else:
self._collections[name].append(value)
def add_to_collections(self, names, value): def add_to_collections(self, names, value):
"""Stores `value` in the collections given by `names`. """Stores `value` in the collections given by `names`.
@ -2543,11 +2559,12 @@ class Graph(object):
The list of values in the collection with the given `name`, or an empty The list of values in the collection with the given `name`, or an empty
list if no value has been added to that collection. list if no value has been added to that collection.
""" """
coll_list = self._collections.get(name, None) with self._lock:
if coll_list is None: coll_list = self._collections.get(name, None)
coll_list = [] if coll_list is None:
self._collections[name] = coll_list coll_list = []
return coll_list self._collections[name] = coll_list
return coll_list
def get_collection(self, name, scope=None): def get_collection(self, name, scope=None):
"""Returns a list of values in the collection with the given `name`. """Returns a list of values in the collection with the given `name`.
@ -2571,22 +2588,24 @@ class Graph(object):
list contains the values in the order under which they were list contains the values in the order under which they were
collected. collected.
""" """
coll_list = self._collections.get(name, None) with self._lock:
if coll_list is None: coll_list = self._collections.get(name, None)
return [] if coll_list is None:
if scope is None: return []
return list(coll_list) if scope is None:
else: return list(coll_list)
c = [] else:
regex = re.compile(scope) c = []
for item in coll_list: regex = re.compile(scope)
if hasattr(item, "name") and regex.match(item.name): for item in coll_list:
c.append(item) if hasattr(item, "name") and regex.match(item.name):
return c c.append(item)
return c
def get_all_collection_keys(self): def get_all_collection_keys(self):
"""Returns a list of collections used in this graph.""" """Returns a list of collections used in this graph."""
return [x for x in self._collections if isinstance(x, six.string_types)] with self._lock:
return [x for x in self._collections if isinstance(x, six.string_types)]
@contextlib.contextmanager @contextlib.contextmanager
def _original_op(self, op): def _original_op(self, op):

View File

@ -412,6 +412,17 @@ class ConcatOpTest(tf.test.TestCase):
self.assertEqual(n + 3, after - before) self.assertEqual(n + 3, after - before)
print("graph = ", [x.name for x in g.get_operations()]) print("graph = ", [x.name for x in g.get_operations()])
def testConcatLargeTensors(self):
# CPU-only test, because it fails on GPUs with <= 4GB memory.
with tf.device("/cpu:0"):
a = tf.ones([2**31 + 6], dtype=tf.int8)
b = tf.zeros([1024], dtype=tf.int8)
onezeros = tf.concat(0, [a, b])
with self.test_session(use_gpu=False):
# TODO(dga): Add more depth to this test to validate correctness,
# not just non-crashingness, once other large tensor fixes have gone in.
_ = onezeros.eval()
class ConcatOffsetTest(tf.test.TestCase): class ConcatOffsetTest(tf.test.TestCase):

View File

@ -158,14 +158,14 @@ class MatMulTest(tf.test.TestCase):
def testComplex64Random(self): def testComplex64Random(self):
for _ in range(10): for _ in range(10):
n, k, m = np.random.randint(1, 100, size=3) n, k, m = np.random.randint(1, 10, size=3) # Smaller range than float
x = self._randMatrix(n, k, np.complex64) x = self._randMatrix(n, k, np.complex64)
y = self._randMatrix(k, m, np.complex64) y = self._randMatrix(k, m, np.complex64)
self._testCpuMatmul(x, y) self._testCpuMatmul(x, y)
def testComplex128Random(self): def testComplex128Random(self):
for _ in range(10): for _ in range(10):
n, k, m = np.random.randint(1, 100, size=3) n, k, m = np.random.randint(1, 10, size=3) # Smaller range than float
x = self._randMatrix(n, k, np.complex128) x = self._randMatrix(n, k, np.complex128)
y = self._randMatrix(k, m, np.complex128) y = self._randMatrix(k, m, np.complex128)
self._testCpuMatmul(x, y) self._testCpuMatmul(x, y)

View File

@ -417,16 +417,27 @@ class SparseFillEmptyRowsTest(test_util.TensorFlowTestCase):
class SparseReduceSumTest(test_util.TensorFlowTestCase): class SparseReduceSumTest(test_util.TensorFlowTestCase):
def _compare(self, sp_t, reduction_axes, keep_dims): # [[1, ?, 1]
# [?, 1, ?]]
# where ? is implicitly-zero.
ind = np.array([[0, 0], [0, 2], [1, 1]]).astype(np.int64)
vals = np.array([1, 1, 1]).astype(np.int32)
shape = np.array([2, 3]).astype(np.int64)
def _compare(self, sp_t, reduction_axes, ndims, keep_dims):
densified = sparse_ops.sparse_tensor_to_dense(sp_t).eval() densified = sparse_ops.sparse_tensor_to_dense(sp_t).eval()
np_ans = densified np_ans = densified
if reduction_axes is None: if reduction_axes is None:
np_ans = np.sum(np_ans, keepdims=keep_dims) np_ans = np.sum(np_ans, keepdims=keep_dims)
else: else:
if isinstance(reduction_axes, list): if not isinstance(reduction_axes, list): # Single scalar.
reduction_axes = sorted(reduction_axes) # loop below depends on sorted reduction_axes = [reduction_axes]
reduction_axes = np.array(reduction_axes).astype(np.int32) reduction_axes = np.array(reduction_axes).astype(np.int32)
# Handles negative axes.
reduction_axes = (reduction_axes + ndims) % ndims
# Loop below depends on sorted.
reduction_axes.sort()
for ra in reduction_axes.ravel()[::-1]: for ra in reduction_axes.ravel()[::-1]:
np_ans = np.sum(np_ans, axis=ra, keepdims=keep_dims) np_ans = np.sum(np_ans, axis=ra, keepdims=keep_dims)
@ -436,25 +447,21 @@ class SparseReduceSumTest(test_util.TensorFlowTestCase):
self.assertAllClose(np_ans, out) self.assertAllClose(np_ans, out)
def _compare_all(self, sp_t, reduction_axes): def _compare_all(self, sp_t, reduction_axes, ndims):
self._compare(sp_t, reduction_axes, False) self._compare(sp_t, reduction_axes, ndims, False)
self._compare(sp_t, reduction_axes, True) self._compare(sp_t, reduction_axes, ndims, True)
def testSimpleAndRandomInputs(self): def testSimpleAndRandomInputs(self):
# [[1, ?, 1] sp_t = ops.SparseTensor(self.ind, self.vals, self.shape)
# [?, 1, ?]]
# where ? is implicitly-zero.
ind = np.array([[0, 0], [0, 2], [1, 1]]).astype(np.int64)
vals = np.array([1, 1, 1]).astype(np.int32)
shape = np.array([2, 3]).astype(np.int64)
sp_t = ops.SparseTensor(ind, vals, shape)
with self.test_session(use_gpu=False): with self.test_session(use_gpu=False):
self._compare_all(sp_t, None) self._compare_all(sp_t, None, ndims=2)
self._compare_all(sp_t, 0) self._compare_all(sp_t, 0, ndims=2)
self._compare_all(sp_t, [1]) self._compare_all(sp_t, [1], ndims=2)
self._compare_all(sp_t, [0, 1]) self._compare_all(sp_t, [0, 1], ndims=2)
self._compare_all(sp_t, [1, 0]) self._compare_all(sp_t, [1, 0], ndims=2)
self._compare_all(sp_t, [-1], ndims=2)
self._compare_all(sp_t, [1, -2], ndims=2)
np.random.seed(1618) np.random.seed(1618)
test_dims = [(1618, 1, 11, 7, 1), (1,), (1, 1, 1)] test_dims = [(1618, 1, 11, 7, 1), (1,), (1, 1, 1)]
@ -462,11 +469,19 @@ class SparseReduceSumTest(test_util.TensorFlowTestCase):
for dims in test_dims: for dims in test_dims:
sp_t, unused_nnz = _sparsify(np.random.randn(*dims)) sp_t, unused_nnz = _sparsify(np.random.randn(*dims))
# reduce all using None # reduce all using None
self._compare_all(sp_t, None) self._compare_all(sp_t, None, ndims=len(dims))
# reduce random axes from 1D to N-D # reduce random axes from 1D to N-D
for d in range(1, len(dims) + 1): for d in range(1, len(dims) + 1):
axes = np.random.choice(len(dims), size=d, replace=False).tolist() axes = np.random.choice(len(dims), size=d, replace=False).tolist()
self._compare_all(sp_t, axes) self._compare_all(sp_t, axes, ndims=len(dims))
def testInvalidAxes(self):
sp_t = ops.SparseTensor(self.ind, self.vals, self.shape)
with self.test_session(use_gpu=False):
with self.assertRaisesOpError("Invalid reduction dimension -3"):
sparse_ops.sparse_reduce_sum(sp_t, -3).eval()
with self.assertRaisesOpError("Invalid reduction dimension 2"):
sparse_ops.sparse_reduce_sum(sp_t, 2).eval()
def testGradient(self): def testGradient(self):
np.random.seed(8161) np.random.seed(8161)
@ -483,6 +498,12 @@ class SparseReduceSumTest(test_util.TensorFlowTestCase):
reduced.eval().shape) reduced.eval().shape)
self.assertLess(err, 1e-3) self.assertLess(err, 1e-3)
# Tests for negative axes.
reduced = sparse_ops.sparse_reduce_sum(sp_t, -1)
err = tf.test.compute_gradient_error(sp_t.values, (nnz,), reduced,
reduced.eval().shape)
self.assertLess(err, 1e-3)
class SparseMathOpsTest(test_util.TensorFlowTestCase): class SparseMathOpsTest(test_util.TensorFlowTestCase):

View File

@ -225,7 +225,7 @@ def parse_example(serialized, features, name=None, example_names=None):
features: { features: {
"kw": VarLenFeature(tf.string), "kw": VarLenFeature(tf.string),
"dank": VarLenFeature(tf.int64), "dank": VarLenFeature(tf.int64),
"gps": VarLenFeature(tf.float), "gps": VarLenFeature(tf.float32),
} }
``` ```
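
A minimal usage sketch of the corrected spec (hedged: it assumes the `tf.parse_example` signature documented above, and reuses the illustrative feature names): `tf.float` is not a TensorFlow dtype, so variable-length float features must be declared with `tf.float32`.

```
import tensorflow as tf

serialized = tf.placeholder(tf.string, shape=[None])  # batch of serialized Example protos
parsed = tf.parse_example(
    serialized,
    features={
        "kw": tf.VarLenFeature(tf.string),
        "dank": tf.VarLenFeature(tf.int64),
        "gps": tf.VarLenFeature(tf.float32),  # tf.float does not exist; use tf.float32
    })
# Each value in `parsed` is a SparseTensor, since VarLenFeature lengths vary per example.
```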

View File

@ -548,7 +548,8 @@ def sparse_reduce_sum(sp_input, reduction_axes=None, keep_dims=False):
with length 1. with length 1.
If `reduction_axes` has no entries, all dimensions are reduced, and a tensor If `reduction_axes` has no entries, all dimensions are reduced, and a tensor
with a single element is returned. with a single element is returned. Additionally, the axes can be negative,
similar to the indexing rules in Python.
For example: For example:
@ -558,7 +559,7 @@ def sparse_reduce_sum(sp_input, reduction_axes=None, keep_dims=False):
# where ? is implicitly-zero. # where ? is implicitly-zero.
tf.sparse_reduce_sum(x) ==> 3 tf.sparse_reduce_sum(x) ==> 3
tf.sparse_reduce_sum(x, 0) ==> [1, 1, 1] tf.sparse_reduce_sum(x, 0) ==> [1, 1, 1]
tf.sparse_reduce_sum(x, 1) ==> [2, 1] tf.sparse_reduce_sum(x, 1) ==> [2, 1] # Can also use -1 as the axis.
tf.sparse_reduce_sum(x, 1, keep_dims=True) ==> [[2], [1]] tf.sparse_reduce_sum(x, 1, keep_dims=True) ==> [[2], [1]]
tf.sparse_reduce_sum(x, [0, 1]) ==> 3 tf.sparse_reduce_sum(x, [0, 1]) ==> 3
``` ```
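
The negative-axis rule follows Python indexing, matching the normalization the updated test performs. A small pure-NumPy sketch of that normalization (the function name is illustrative):

```
import numpy as np

def normalize_axes(reduction_axes, ndims):
  axes = np.array(reduction_axes).astype(np.int32)
  axes = (axes + ndims) % ndims  # e.g. axis -1 on a rank-2 input becomes 1
  axes.sort()                    # the reduction loop depends on sorted order
  return axes

print(normalize_axes([-1], 2))     # [1]
print(normalize_axes([1, -2], 2))  # [0 1]
```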

View File

@ -114,8 +114,7 @@ class EventAccumulator(object):
`Accumulator.Scalars(tag)`) allow for the retrieval of all data `Accumulator.Scalars(tag)`) allow for the retrieval of all data
associated with that tag. associated with that tag.
Before usage, the `EventAccumulator` must be activated via `Reload()`. This The `Reload()` method synchronously loads all of the data written so far.
method synchronously loads all of the data written so far.
Histograms, audio, and images are very large, so storing all of them is not Histograms, audio, and images are very large, so storing all of them is not
recommended. recommended.
@ -175,7 +174,6 @@ class EventAccumulator(object):
self._compression_bps = compression_bps self._compression_bps = compression_bps
self.purge_orphaned_data = purge_orphaned_data self.purge_orphaned_data = purge_orphaned_data
self._activated = False
self.most_recent_step = -1 self.most_recent_step = -1
self.most_recent_wall_time = -1 self.most_recent_wall_time = -1
self.file_version = None self.file_version = None
@ -188,12 +186,10 @@ class EventAccumulator(object):
"""Loads all events added since the last call to `Reload`. """Loads all events added since the last call to `Reload`.
If `Reload` was never called, loads all events in the file. If `Reload` was never called, loads all events in the file.
Calling `Reload` activates the `EventAccumulator`.
Returns: Returns:
The `EventAccumulator`. The `EventAccumulator`.
""" """
self._activated = True
with self._generator_mutex: with self._generator_mutex:
for event in self._generator.Load(): for event in self._generator.Load():
if event.HasField('file_version'): if event.HasField('file_version'):
@ -232,13 +228,9 @@ class EventAccumulator(object):
def Tags(self): def Tags(self):
"""Return all tags found in the value stream. """Return all tags found in the value stream.
Raises:
RuntimeError: If the `EventAccumulator` has not been activated.
Returns: Returns:
A `{tagType: ['list', 'of', 'tags']}` dictionary. A `{tagType: ['list', 'of', 'tags']}` dictionary.
""" """
self._VerifyActivated()
return {IMAGES: self._images.Keys(), return {IMAGES: self._images.Keys(),
AUDIO: self._audio.Keys(), AUDIO: self._audio.Keys(),
HISTOGRAMS: self._histograms.Keys(), HISTOGRAMS: self._histograms.Keys(),
@ -255,12 +247,10 @@ class EventAccumulator(object):
Raises: Raises:
KeyError: If the tag is not found. KeyError: If the tag is not found.
RuntimeError: If the `EventAccumulator` has not been activated.
Returns: Returns:
An array of `ScalarEvent`s. An array of `ScalarEvent`s.
""" """
self._VerifyActivated()
return self._scalars.Items(tag) return self._scalars.Items(tag)
def Graph(self): def Graph(self):
@ -268,12 +258,10 @@ class EventAccumulator(object):
Raises: Raises:
ValueError: If there is no graph for this run. ValueError: If there is no graph for this run.
RuntimeError: If the `EventAccumulator` has not been activated.
Returns: Returns:
The `graph_def` proto. The `graph_def` proto.
""" """
self._VerifyActivated()
if self._graph is None: if self._graph is None:
raise ValueError('There is no graph in this EventAccumulator') raise ValueError('There is no graph in this EventAccumulator')
graph = graph_pb2.GraphDef() graph = graph_pb2.GraphDef()
@ -288,12 +276,10 @@ class EventAccumulator(object):
Raises: Raises:
ValueError: If the tag is not found. ValueError: If the tag is not found.
RuntimeError: If the `EventAccumulator` has not been activated.
Returns: Returns:
The metadata in form of `RunMetadata` proto. The metadata in form of `RunMetadata` proto.
""" """
self._VerifyActivated()
if tag not in self._tagged_metadata: if tag not in self._tagged_metadata:
raise ValueError('There is no run metadata with this tag name') raise ValueError('There is no run metadata with this tag name')
@ -309,12 +295,10 @@ class EventAccumulator(object):
Raises: Raises:
KeyError: If the tag is not found. KeyError: If the tag is not found.
RuntimeError: If the `EventAccumulator` has not been activated.
Returns: Returns:
An array of `HistogramEvent`s. An array of `HistogramEvent`s.
""" """
self._VerifyActivated()
return self._histograms.Items(tag) return self._histograms.Items(tag)
def CompressedHistograms(self, tag): def CompressedHistograms(self, tag):
@ -325,12 +309,10 @@ class EventAccumulator(object):
Raises: Raises:
KeyError: If the tag is not found. KeyError: If the tag is not found.
RuntimeError: If the `EventAccumulator` has not been activated.
Returns: Returns:
An array of `CompressedHistogramEvent`s. An array of `CompressedHistogramEvent`s.
""" """
self._VerifyActivated()
return self._compressed_histograms.Items(tag) return self._compressed_histograms.Items(tag)
def Images(self, tag): def Images(self, tag):
@ -341,12 +323,10 @@ class EventAccumulator(object):
Raises: Raises:
KeyError: If the tag is not found. KeyError: If the tag is not found.
RuntimeError: If the `EventAccumulator` has not been activated.
Returns: Returns:
An array of `ImageEvent`s. An array of `ImageEvent`s.
""" """
self._VerifyActivated()
return self._images.Items(tag) return self._images.Items(tag)
def Audio(self, tag): def Audio(self, tag):
@ -357,12 +337,10 @@ class EventAccumulator(object):
Raises: Raises:
KeyError: If the tag is not found. KeyError: If the tag is not found.
RuntimeError: If the `EventAccumulator` has not been activated.
Returns: Returns:
An array of `AudioEvent`s. An array of `AudioEvent`s.
""" """
self._VerifyActivated()
return self._audio.Items(tag) return self._audio.Items(tag)
def _MaybePurgeOrphanedData(self, event): def _MaybePurgeOrphanedData(self, event):
@ -599,10 +577,6 @@ class EventAccumulator(object):
event.wall_time, *expired_per_type) event.wall_time, *expired_per_type)
logging.warn(purge_msg) logging.warn(purge_msg)
def _VerifyActivated(self):
if not self._activated:
raise RuntimeError('Accumulator must be activated before it may be used.')
def _GetPurgeMessage(most_recent_step, most_recent_wall_time, event_step, def _GetPurgeMessage(most_recent_step, most_recent_wall_time, event_step,
event_wall_time, num_expired_scalars, num_expired_histos, event_wall_time, num_expired_scalars, num_expired_histos,
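
With activation removed, the accumulator's lifecycle reduces to construct, `Reload()`, query. A hedged usage sketch (the events path and tag name are hypothetical; the module path assumes the then-current `tensorflow.python.summary` package):

```
from tensorflow.python.summary import event_accumulator as ea

acc = ea.EventAccumulator('path/to/run')  # hypothetical directory of event files
acc.Reload()       # synchronously loads everything written so far; no activation step
print(acc.Tags())  # a {tagType: ['list', 'of', 'tags']} dictionary
for event in acc.Scalars('loss'):  # raises KeyError if the tag is unknown
  print(event.step, event.value)
```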

View File

@ -456,18 +456,6 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
self.assertEqual(acc.Audio('snd1'), [snd1]) self.assertEqual(acc.Audio('snd1'), [snd1])
self.assertEqual(acc.Audio('snd2'), [snd2]) self.assertEqual(acc.Audio('snd2'), [snd2])
def testActivation(self):
gen = _EventGenerator()
acc = ea.EventAccumulator(gen)
self.assertFalse(acc._activated)
with self.assertRaises(RuntimeError):
acc.Tags()
with self.assertRaises(RuntimeError):
acc.Scalars('s1')
acc.Reload()
self.assertTrue(acc._activated)
acc._activated = False
def testKeyError(self): def testKeyError(self):
gen = _EventGenerator() gen = _EventGenerator()
acc = ea.EventAccumulator(gen) acc = ea.EventAccumulator(gen)

View File

@ -113,8 +113,7 @@ class EventMultiplexer(object):
accumulator. accumulator.
If `Reload` has been called, it will `Reload` the newly created If `Reload` has been called, it will `Reload` the newly created
accumulators. This maintains the invariant that once the Multiplexer was accumulators.
activated, all of its accumulators are active.
Args: Args:
path: Path to the event files (or event directory) for given run. path: Path to the event files (or event directory) for given run.
@ -199,7 +198,6 @@ class EventMultiplexer(object):
Raises: Raises:
KeyError: If the run is not found, or the tag is not available for KeyError: If the run is not found, or the tag is not available for
the given run. the given run.
RuntimeError: If the run's `EventAccumulator` has not been activated.
Returns: Returns:
An array of `event_accumulator.ScalarEvents`. An array of `event_accumulator.ScalarEvents`.
@ -216,7 +214,6 @@ class EventMultiplexer(object):
Raises: Raises:
KeyError: If the run is not found. KeyError: If the run is not found.
ValueError: If the run does not have an associated graph. ValueError: If the run does not have an associated graph.
RuntimeError: If the run's EventAccumulator has not been activated.
Returns: Returns:
The `graph_def` protobuf data structure. The `graph_def` protobuf data structure.
@ -234,7 +231,6 @@ class EventMultiplexer(object):
Raises: Raises:
KeyError: If the run is not found, or the tag is not available for the KeyError: If the run is not found, or the tag is not available for the
given run. given run.
RuntimeError: If the run's EventAccumulator has not been activated.
Returns: Returns:
The metadata in the form of `RunMetadata` protobuf data structure. The metadata in the form of `RunMetadata` protobuf data structure.
@ -252,7 +248,6 @@ class EventMultiplexer(object):
Raises: Raises:
KeyError: If the run is not found, or the tag is not available for KeyError: If the run is not found, or the tag is not available for
the given run. the given run.
RuntimeError: If the run's `EventAccumulator` has not been activated.
Returns: Returns:
An array of `event_accumulator.HistogramEvents`. An array of `event_accumulator.HistogramEvents`.
@ -270,7 +265,6 @@ class EventMultiplexer(object):
Raises: Raises:
KeyError: If the run is not found, or the tag is not available for KeyError: If the run is not found, or the tag is not available for
the given run. the given run.
RuntimeError: If the run's EventAccumulator has not been activated.
Returns: Returns:
An array of `event_accumulator.CompressedHistogramEvents`. An array of `event_accumulator.CompressedHistogramEvents`.
@ -288,7 +282,6 @@ class EventMultiplexer(object):
Raises: Raises:
KeyError: If the run is not found, or the tag is not available for KeyError: If the run is not found, or the tag is not available for
the given run. the given run.
RuntimeError: If the run's `EventAccumulator` has not been activated.
Returns: Returns:
An array of `event_accumulator.ImageEvents`. An array of `event_accumulator.ImageEvents`.
@ -306,7 +299,6 @@ class EventMultiplexer(object):
Raises: Raises:
KeyError: If the run is not found, or the tag is not available for KeyError: If the run is not found, or the tag is not available for
the given run. the given run.
RuntimeError: If the run's `EventAccumulator` has not been activated.
Returns: Returns:
An array of `event_accumulator.AudioEvents`. An array of `event_accumulator.AudioEvents`.

View File

@ -184,6 +184,7 @@ bool IsCudnnR2() {
__macro(cudnnSetStream) \ __macro(cudnnSetStream) \
__macro(cudnnActivationForward) \ __macro(cudnnActivationForward) \
__macro(cudnnConvolutionForward) \ __macro(cudnnConvolutionForward) \
__macro(cudnnConvolutionBackwardBias) \
__macro(cudnnGetConvolutionForwardWorkspaceSize) \ __macro(cudnnGetConvolutionForwardWorkspaceSize) \
__macro(cudnnTransformTensor) \ __macro(cudnnTransformTensor) \
__macro(cudnnSetConvolutionNdDescriptor) \ __macro(cudnnSetConvolutionNdDescriptor) \
@ -1493,6 +1494,72 @@ bool CudnnSupport::DoConvolveBackwardFilter(
algorithm, output_profile_result); algorithm, output_profile_result);
} }
template <class T>
bool CudnnSupport::DoConvolveBackwardBiasImpl(
Stream* stream, int cudnn_type, // Actually cudnnDataType_t.
const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<T>* backward_bias_data) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "failed to set stream for cudnn handle: " << ToString(status);
}
ScopedTensorDescriptor input_nd{parent_, input_descriptor,
static_cast<cudnnDataType_t>(cudnn_type)};
ScopedTensorDescriptor bias_nd{parent_, bias_descriptor,
static_cast<cudnnDataType_t>(cudnn_type)};
// Alpha is the scaling factor for input.
float alpha = 1.0;
// Beta is the scaling factor for output.
float beta = 0.0;
status = dynload::cudnnConvolutionBackwardBias(
parent_, ToHandle(dnn_handle_), &alpha, input_nd.handle(),
input_data.opaque(), &beta, bias_nd.handle(),
backward_bias_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "failed to enqueue backward convolution on stream: "
<< ToString(status);
return false;
}
return true;
}
bool CudnnSupport::DoConvolveBackwardBias(
Stream* stream, const BatchDescriptor& input_descriptor,
const DeviceMemory<double>& input_data,
const BatchDescriptor& bias_descriptor,
DeviceMemory<double>* backward_bias_data) {
return DoConvolveBackwardBiasImpl(stream, CUDNN_DATA_DOUBLE, input_descriptor,
input_data, bias_descriptor,
backward_bias_data);
}
bool CudnnSupport::DoConvolveBackwardBias(
Stream* stream, const BatchDescriptor& input_descriptor,
const DeviceMemory<float>& input_data,
const BatchDescriptor& bias_descriptor,
DeviceMemory<float>* backward_bias_data) {
return DoConvolveBackwardBiasImpl(stream, CUDNN_DATA_FLOAT, input_descriptor,
input_data, bias_descriptor,
backward_bias_data);
}
bool CudnnSupport::DoConvolveBackwardBias(
Stream* stream, const BatchDescriptor& input_descriptor,
const DeviceMemory<Eigen::half>& input_data,
const BatchDescriptor& bias_descriptor,
DeviceMemory<Eigen::half>* backward_bias_data) {
return DoConvolveBackwardBiasImpl(stream, CUDNN_DATA_HALF, input_descriptor,
input_data, bias_descriptor,
backward_bias_data);
}
bool CudnnSupport::DoMatMul(Stream* stream, bool CudnnSupport::DoMatMul(Stream* stream,
const DeviceMemory<float>& input_data, const DeviceMemory<float>& input_data,
const DeviceMemory<float>& weights, const DeviceMemory<float>& weights,

View File

@ -140,6 +140,24 @@ class CudnnSupport : public dnn::DnnSupport {
ScratchAllocator* scratch_allocator, dnn::AlgorithmType algorithm, ScratchAllocator* scratch_allocator, dnn::AlgorithmType algorithm,
dnn::ProfileResult* output_profile_result) override; dnn::ProfileResult* output_profile_result) override;
bool DoConvolveBackwardBias(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<double>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<double>* backward_bias_data) override;
bool DoConvolveBackwardBias(Stream* stream,
const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<float>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<float>* backward_bias_data) override;
bool DoConvolveBackwardBias(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<Eigen::half>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<Eigen::half>* backward_bias_data) override;
bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data, bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
const DeviceMemory<float>& weights, const DeviceMemory<float>& weights,
const dnn::BatchDescriptor& input_dimensions, const dnn::BatchDescriptor& input_dimensions,
@ -311,6 +329,14 @@ class CudnnSupport : public dnn::DnnSupport {
dnn::AlgorithmType algorithm, dnn::AlgorithmType algorithm,
dnn::ProfileResult* output_profile_result); dnn::ProfileResult* output_profile_result);
template <class T>
bool DoConvolveBackwardBiasImpl(Stream* stream,
int cudnn_type, // Actually cudnnDataType_t.
const dnn::BatchDescriptor& input_descriptor,
const DeviceMemory<T>& input_data,
const dnn::BatchDescriptor& bias_descriptor,
DeviceMemory<T>* backward_bias_data);
SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport); SE_DISALLOW_COPY_AND_ASSIGN(CudnnSupport);
}; };

View File

@ -849,6 +849,43 @@ class DnnSupport {
ScratchAllocator* scratch_allocator, AlgorithmType algorithm, ScratchAllocator* scratch_allocator, AlgorithmType algorithm,
ProfileResult* output_profile_result) = 0; ProfileResult* output_profile_result) = 0;
// Enqueues a single-precision backward convolution (for bias) operation onto
// the stream.
//
// Arguments:
// stream: borrowed pointer to the stream that the 'convolve' operation
// should be enqueued onto.
// input_descriptor: dimensions of the input layer.
// input_data: un-owned device memory region which contains the
// convolution input.
// bias_descriptor: dimensions of the bias tensor. Should be the same as the
// input dimensions, but with the spatial dimensions set to 1.
// backward_bias_data: un-owned device memory region in which to place the
// backprop of the bias.
virtual bool DoConvolveBackwardBias(Stream* stream,
const BatchDescriptor& input_descriptor,
const DeviceMemory<float>& input_data,
const BatchDescriptor& bias_descriptor,
DeviceMemory<float>* backward_bias_data) {
return false;
}
virtual bool DoConvolveBackwardBias(
Stream* stream, const BatchDescriptor& input_descriptor,
const DeviceMemory<double>& input_data,
const BatchDescriptor& bias_descriptor,
DeviceMemory<double>* backward_bias_data) {
return false;
}
virtual bool DoConvolveBackwardBias(
Stream* stream, const BatchDescriptor& input_descriptor,
const DeviceMemory<Eigen::half>& input_data,
const BatchDescriptor& bias_descriptor,
DeviceMemory<Eigen::half>* backward_bias_data) {
return false;
}
// Fully connects the "nodes" (float values) in input_data with // Fully connects the "nodes" (float values) in input_data with
// shape input_dimensions to output_data with output_dimensions // shape input_dimensions to output_data with output_dimensions
// using provided weights. This is equivalent to computing a matrix // using provided weights. This is equivalent to computing a matrix
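
Stepping back from the C++ plumbing: what `DoConvolveBackwardBias` produces is the bias gradient of a convolution-plus-bias layer, i.e. the incoming gradient reduced over every dimension except channels. A hedged TensorFlow-level illustration of that quantity (via `tf.gradients`; this is not the StreamExecutor API itself):

```
import tensorflow as tf

x = tf.random_normal([8, 32, 32, 16])  # NHWC activations feeding the bias add
b = tf.Variable(tf.zeros([16]))        # one bias per channel
y = tf.nn.bias_add(x, b)
# d(sum(y))/db has shape [16]: the upstream gradient summed over N, H and W,
# which is the reduction the backward-bias kernel performs on the device.
db = tf.gradients(tf.reduce_sum(y), b)[0]
```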

View File

@ -741,6 +741,57 @@ Stream &Stream::ThenConvolveBackwardFilter(
/*scratch_allocator=*/nullptr); /*scratch_allocator=*/nullptr);
} }
template <typename T>
Stream &Stream::ThenConvolveBackwardBiasImpl(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<T> &input_data,
const dnn::BatchDescriptor &bias_descriptor,
DeviceMemory<T> *backward_bias_data) {
VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(bias_descriptor),
PARAM(backward_bias_data));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoConvolveBackwardBias(this, input_descriptor, input_data,
bias_descriptor,
backward_bias_data));
} else {
SetError();
LOG(WARNING)
<< "attempting to perform DNN operation using StreamExecutor "
"without DNN support";
}
}
return *this;
}
Stream &Stream::ThenConvolveBackwardBias(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<double> &input_data,
const dnn::BatchDescriptor &bias_descriptor,
DeviceMemory<double> *backward_bias_data) {
return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
bias_descriptor, backward_bias_data);
}
Stream &Stream::ThenConvolveBackwardBias(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &bias_descriptor,
DeviceMemory<float> *backward_bias_data) {
return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
bias_descriptor, backward_bias_data);
}
Stream &Stream::ThenConvolveBackwardBias(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<Eigen::half> &input_data,
const dnn::BatchDescriptor &bias_descriptor,
DeviceMemory<Eigen::half> *backward_bias_data) {
return ThenConvolveBackwardBiasImpl(input_descriptor, input_data,
bias_descriptor, backward_bias_data);
}
Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data, Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
const DeviceMemory<float> &weights, const DeviceMemory<float> &weights,
const dnn::BatchDescriptor &input_dimensions, const dnn::BatchDescriptor &input_dimensions,

View File

@ -371,6 +371,22 @@ class Stream {
ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm, ScratchAllocator *scratch_allocator, dnn::AlgorithmType algorithm,
dnn::ProfileResult *output_profile_result); dnn::ProfileResult *output_profile_result);
Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<double> &input_data,
const dnn::BatchDescriptor &bias_descriptor,
DeviceMemory<double> *backward_bias_data);
Stream &ThenConvolveBackwardBias(const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &bias_descriptor,
DeviceMemory<float> *backward_bias_data);
Stream &ThenConvolveBackwardBias(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<Eigen::half> &input_data,
const dnn::BatchDescriptor &bias_descriptor,
DeviceMemory<Eigen::half> *backward_bias_data);
Stream &ThenMatMul(const DeviceMemory<float> &input_data, Stream &ThenMatMul(const DeviceMemory<float> &input_data,
const DeviceMemory<float> &weights, const DeviceMemory<float> &weights,
const dnn::BatchDescriptor &input_dimensions, const dnn::BatchDescriptor &input_dimensions,
@ -1439,6 +1455,14 @@ class Stream {
// BlockHostUntilDone() is called. // BlockHostUntilDone() is called.
internal::TemporaryMemoryManager temporary_memory_manager_; internal::TemporaryMemoryManager temporary_memory_manager_;
// Implementation of ThenConvolveBackwardBias that is shared by all types.
template <typename T>
Stream &ThenConvolveBackwardBiasImpl(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<T> &input_data,
const dnn::BatchDescriptor &bias_descriptor,
DeviceMemory<T> *backward_bias_data);
SE_DISALLOW_COPY_AND_ASSIGN(Stream); SE_DISALLOW_COPY_AND_ASSIGN(Stream);
}; };

View File

@ -120,12 +120,9 @@ def StartMultiplexerReloadingThread(multiplexer, path_to_run, load_interval):
Returns: Returns:
A started `threading.Thread` that reloads the multiplexer. A started `threading.Thread` that reloads the multiplexer.
""" """
# Ensure the Multiplexer initializes in a loaded state before it adds runs # We don't call multiplexer.Reload() here because that would make
# So it can handle HTTP requests while runs are loading # AddRunsFromDirectory block until the runs have all loaded.
multiplexer.Reload()
for path in path_to_run.keys(): for path in path_to_run.keys():
if gcs.IsGCSPath(path): if gcs.IsGCSPath(path):
gcs.CheckIsSupported() gcs.CheckIsSupported()
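
The enclosing function's name, `StartMultiplexerReloadingThread`, implies the shape of the fix: loading now happens on a background thread so the HTTP server can answer requests while runs stream in. A minimal sketch of that pattern, assuming only an object with a blocking `Reload()` method:

```
import threading
import time

def _reload_forever(multiplexer, load_interval):
  while True:
    multiplexer.Reload()       # blocks while all runs load
    time.sleep(load_interval)

def start_reloading_thread(multiplexer, load_interval=60):
  thread = threading.Thread(target=_reload_forever,
                            args=(multiplexer, load_interval))
  thread.daemon = True  # don't let the reloader keep the process alive
  thread.start()
  return thread
```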

View File

@ -321,6 +321,11 @@ def _cuda_copts():
"--cuda-gpu-arch=sm_35", "--cuda-gpu-arch=sm_35",
] ]
), ),
}) + select({
# Pass -O3 when building CUDA code with clang; some important
# optimizations are not enabled at O2.
"//third_party/gpus/cuda:using_clang_opt": ["-O3"],
"//conditions:default": [],
}) })
# Build defs for TensorFlow kernels # Build defs for TensorFlow kernels

View File

@ -13,8 +13,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
native.new_http_archive( native.new_http_archive(
name = "eigen_archive", name = "eigen_archive",
url = "https://bitbucket.org/eigen/eigen/get/a5e9085a94e8.tar.gz", url = "https://bitbucket.org/eigen/eigen/get/f3a13643ac1f.tar.gz",
sha256 = "967126237829c7c87abb6cd0e13a5a235b0377d51575522c390b9486aed13e71", sha256 = "a9266e60366cddb371a23d86b11a297eee86372a89ef4b38a3509012f9cc37ec",
build_file = path_prefix + "eigen.BUILD", build_file = path_prefix + "eigen.BUILD",
) )

View File

@ -1 +1 @@
#include "eigen-eigen-a5e9085a94e8/Eigen/Cholesky" #include "eigen-eigen-f3a13643ac1f/Eigen/Cholesky"

View File

@ -1 +1 @@
#include "eigen-eigen-a5e9085a94e8/Eigen/Core" #include "eigen-eigen-f3a13643ac1f/Eigen/Core"

View File

@ -1 +1 @@
#include "eigen-eigen-a5e9085a94e8/Eigen/Eigenvalues" #include "eigen-eigen-f3a13643ac1f/Eigen/Eigenvalues"

View File

@ -1 +1 @@
#include "eigen-eigen-a5e9085a94e8/Eigen/LU" #include "eigen-eigen-f3a13643ac1f/Eigen/LU"

View File

@ -1 +1 @@
#include "eigen-eigen-a5e9085a94e8/Eigen/QR" #include "eigen-eigen-f3a13643ac1f/Eigen/QR"

View File

@ -1 +1 @@
#include "eigen-eigen-a5e9085a94e8/unsupported/Eigen/CXX11/Tensor" #include "eigen-eigen-f3a13643ac1f/unsupported/Eigen/CXX11/Tensor"

View File

@ -31,6 +31,15 @@ config_setting(
}, },
) )
# Equivalent to using_clang && -c opt.
config_setting(
name = "using_clang_opt",
values = {
"define": "using_cuda_clang=true",
"compilation_mode": "opt",
},
)
config_setting( config_setting(
name = "darwin", name = "darwin",
values = {"cpu": "darwin"}, values = {"cpu": "darwin"},