Base classes for probability distributions and uniform distribution

Change: 122194730
A. Unique TensorFlower 2016-05-12 13:02:03 -08:00 committed by TensorFlower Gardener
parent 313408ba1f
commit 1f6cd6fbb0
5 changed files with 736 additions and 1 deletion

tensorflow/contrib/distributions/BUILD

@@ -38,6 +38,16 @@ cuda_py_tests(
    ],
)

cuda_py_tests(
    name = "uniform_test",
    size = "small",
    srcs = ["python/kernel_tests/uniform_test.py"],
    additional_deps = [
        ":distributions_py",
        "//tensorflow/python:platform_test",
    ],
)

cuda_py_tests(
    name = "mvn_test",
    size = "small",

tensorflow/contrib/distributions/__init__.py

@@ -12,16 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Classes representing statistical distributions. Ops for working with them.
"""Classes representing statistical distributions and ops for working with them.

## Classes for statistical distributions.

Classes that represent batches of statistical distributions. Each class is
initialized with parameters that define the distributions.

### Base classes

@@BaseDistribution
@@ContinuousDistribution
@@DiscreteDistribution

### Univariate (scalar) distributions

@@Gaussian
@@Uniform

### Multivariate distributions

@@ -44,6 +51,8 @@ from __future__ import print_function

# pylint: disable=unused-import,wildcard-import,line-too-long
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.distribution import *
from tensorflow.contrib.distributions.python.ops.gaussian import *
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.mvn import *
from tensorflow.contrib.distributions.python.ops.uniform import *
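These wildcard imports re-export the new symbols, so everything listed above surfaces directly under `tf.contrib.distributions` (the tests below depend on this). A minimal sketch, assuming a contrib-era TensorFlow build:

```python
import tensorflow as tf

# `Uniform` and the base classes are re-exported by the wildcard imports.
u = tf.contrib.distributions.Uniform(a=0.0, b=1.0)
assert isinstance(u, tf.contrib.distributions.ContinuousDistribution)
```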

tensorflow/contrib/distributions/python/kernel_tests/uniform_test.py

@@ -0,0 +1,220 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Uniform distribution."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class UniformTest(tf.test.TestCase):
def testUniformRange(self):
with self.test_session():
a = 3.0
b = 10.0
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
self.assertAllClose(a, uniform.a.eval())
self.assertAllClose(b, uniform.b.eval())
self.assertAllClose(b - a, uniform.range.eval())
def testUniformPDF(self):
with self.test_session():
a = tf.constant([-3.0] * 5 + [15.0])
b = tf.constant([11.0] * 5 + [20.0])
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
a_v = -3.0
b_v = 11.0
x = np.array([-10.5, 4.0, 0.0, 10.99, 11.3, 17.0], dtype=np.float32)
def _expected_pdf():
pdf = np.zeros_like(x) + 1.0 / (b_v - a_v)
pdf[x > b_v] = 0.0
pdf[x < a_v] = 0.0
pdf[5] = 1.0 / (20.0 - 15.0)
return pdf
expected_pdf = _expected_pdf()
pdf = uniform.pdf(x)
self.assertAllClose(expected_pdf, pdf.eval())
log_pdf = uniform.log_pdf(x)
self.assertAllClose(np.log(expected_pdf), log_pdf.eval())
def testUniformShape(self):
with self.test_session():
a = tf.constant([-3.0] * 5)
b = tf.constant(11.0)
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
self.assertEqual(uniform.batch_shape().eval(), (5,))
self.assertEqual(uniform.get_batch_shape(), tf.TensorShape([5]))
self.assertEqual(uniform.event_shape().eval(), 1)
self.assertEqual(uniform.get_event_shape(), tf.TensorShape([]))
def testUniformPDFWithScalarEndpoint(self):
with self.test_session():
a = tf.constant([0.0, 5.0])
b = tf.constant(10.0)
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
x = np.array([0.0, 8.0], dtype=np.float32)
expected_pdf = np.array([1.0 / (10.0 - 0.0), 1.0 / (10.0 - 5.0)])
pdf = uniform.pdf(x)
self.assertAllClose(expected_pdf, pdf.eval())
def testUniformCDF(self):
with self.test_session():
batch_size = 6
a = tf.constant([1.0] * batch_size)
b = tf.constant([11.0] * batch_size)
a_v = 1.0
b_v = 11.0
x = np.array([-2.5, 2.5, 4.0, 0.0, 10.99, 12.0], dtype=np.float32)
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
def _expected_cdf():
cdf = (x - a_v) / (b_v - a_v)
cdf[x >= b_v] = 1
cdf[x < a_v] = 0
return cdf
cdf = uniform.cdf(x)
self.assertAllClose(_expected_cdf(), cdf.eval())
log_cdf = uniform.log_cdf(x)
self.assertAllClose(np.log(_expected_cdf()), log_cdf.eval())
def testUniformEntropy(self):
with self.test_session():
a_v = np.array([1.0, 1.0, 1.0])
b_v = np.array([[1.5, 2.0, 3.0]])
uniform = tf.contrib.distributions.Uniform(a=a_v, b=b_v)
expected_entropy = np.log(b_v - a_v)
self.assertAllClose(expected_entropy, uniform.entropy().eval())
def testUniformAssertMaxGtMin(self):
with self.test_session():
a_v = np.array([1.0, 1.0, 1.0], dtype=np.float32)
b_v = np.array([1.0, 2.0, 3.0], dtype=np.float32)
uniform = tf.contrib.distributions.Uniform(a=a_v, b=b_v)
with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
"x < y"):
uniform.a.eval()
def testUniformSample(self):
with self.test_session():
a = tf.constant([3.0, 4.0])
b = tf.constant(13.0)
a1_v = 3.0
a2_v = 4.0
b_v = 13.0
n = tf.constant(100000)
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
samples = uniform.sample(n, seed=137)
sample_values = samples.eval()
self.assertEqual(sample_values.shape, (100000, 2))
self.assertAllClose(sample_values[::, 0].mean(), (b_v + a1_v) / 2,
atol=1e-2)
self.assertAllClose(sample_values[::, 1].mean(), (b_v + a2_v) / 2,
atol=1e-2)
self.assertFalse(np.any(sample_values[::, 0] < a1_v) or np.any(
sample_values >= b_v))
self.assertFalse(np.any(sample_values[::, 1] < a2_v) or np.any(
sample_values >= b_v))
def testUniformSampleMultiDimensional(self):
with self.test_session():
batch_size = 2
a_v = [3.0, 22.0]
b_v = [13.0, 35.0]
a = tf.constant([a_v] * batch_size)
b = tf.constant([b_v] * batch_size)
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
n_v = 100000
n = tf.constant(n_v)
samples = uniform.sample(n, seed=138)
self.assertEqual(samples.get_shape(), (n_v, batch_size, 2))
sample_values = samples.eval()
self.assertFalse(np.any(sample_values[:, 0, 0] < a_v[0]) or np.any(
sample_values[:, 0, 0] >= b_v[0]))
self.assertFalse(np.any(sample_values[:, 0, 1] < a_v[1]) or np.any(
sample_values[:, 0, 1] >= b_v[1]))
self.assertAllClose(sample_values[:, 0, 0].mean(), (a_v[0] + b_v[0]) / 2,
atol=1e-2)
self.assertAllClose(sample_values[:, 0, 1].mean(), (a_v[1] + b_v[1]) / 2,
atol=1e-2)
def testUniformMeanAndVariance(self):
with self.test_session():
a = 10.0
b = 100.0
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
self.assertAllClose(uniform.variance.eval(), (b - a)**2 / 12)
self.assertAllClose(uniform.mean.eval(), (b + a) / 2)
def testUniformNans(self):
with self.test_session():
a = 10.0
b = [11.0, 100.0]
uniform = tf.contrib.distributions.Uniform(a=a, b=b)
no_nans = tf.constant(1.0)
nans = tf.constant(0.0) / tf.constant(0.0)
self.assertTrue(tf.is_nan(nans).eval())
with_nans = tf.pack([no_nans, nans])
pdf = uniform.pdf(with_nans)
is_nan = tf.is_nan(pdf).eval()
print(pdf.eval())
self.assertFalse(is_nan[0])
self.assertTrue(is_nan[1])
def testUniformSamplePdf(self):
with self.test_session():
a = 10.0
b = [11.0, 100.0]
uniform = tf.contrib.distributions.Uniform(a, b)
self.assertTrue(tf.reduce_all(uniform.pdf(uniform.sample(10)) > 0).eval())
def testUniformBroadcasting(self):
with self.test_session():
a = 10.0
b = [11.0, 20.0]
uniform = tf.contrib.distributions.Uniform(a, b)
pdf = uniform.pdf([[10.5, 11.5], [9.0, 19.0], [10.5, 21.0]])
expected_pdf = np.array([[1.0, 0.1], [0.0, 0.1], [1.0, 0.0]])
self.assertAllClose(expected_pdf, pdf.eval())
if __name__ == "__main__":
tf.test.main()
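The statistical assertions above rely on the standard closed forms for `Uniform(a, b)`: mean `(a + b) / 2`, variance `(b - a)**2 / 12`, and entropy `log(b - a)`. A quick standalone numpy sanity check of those identities (a hypothetical snippet, not part of the commit):

```python
import numpy as np

a, b = 10.0, 100.0
rng = np.random.RandomState(0)
samples = rng.uniform(a, b, size=200000)

# Empirical moments should approach the closed forms the tests assert.
assert abs(samples.mean() - (a + b) / 2) < 0.5
assert abs(samples.var() - (b - a) ** 2 / 12) < 5.0
print(np.log(b - a))  # entropy of Uniform(a, b) in nats
```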

tensorflow/contrib/distributions/python/ops/distribution.py

@@ -0,0 +1,256 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base classes for probability distributions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
@six.add_metaclass(abc.ABCMeta)
class BaseDistribution(object):
  """Abstract base class for probability distributions.

  This class, along with `ContinuousDistribution` and `DiscreteDistribution`,
  defines the API for probability distributions.

  Users will never instantiate a `BaseDistribution`, but will instead
  instantiate subclasses of either `ContinuousDistribution` or
  `DiscreteDistribution`.

  Developers of new distributions should prefer to subclass
  `ContinuousDistribution` or `DiscreteDistribution`.

  ### API

  The key methods for probability distributions are defined here. The
  likelihood functions (`pdf`, `log_pdf`) and (`pmf`, `log_pmf`) are defined
  in `ContinuousDistribution` and `DiscreteDistribution`, respectively.

  To keep ops generated by the distribution tied together by name, subclasses
  should override `name` and use it to prepend the names of ops in other
  methods (see `cdf` for an example).

  Subclasses that wish to support `cdf` and `log_cdf` can override `log_cdf`
  and use the base class's implementation for `cdf`.

  ### Broadcasting, batching, and shapes

  All distributions support batches of independent distributions of that type.
  The batch shape is determined by broadcasting together the parameters.

  The shape of arguments to `__init__`, `cdf`, `log_cdf`, and the likelihood
  functions defined in `ContinuousDistribution` and `DiscreteDistribution`
  reflect this broadcasting, as does the return value of `sample`.

  `sample_shape = (n,) + batch_shape + event_shape`, where `sample_shape` is
  the shape of the `Tensor` returned from `sample`, `n` is the number of
  samples, `batch_shape` defines how many independent distributions there are,
  and `event_shape` defines the shape of samples from each of those
  independent distributions. Samples are independent along the `batch_shape`
  dimensions, but not necessarily so along the `event_shape` dimensions
  (depending on the particulars of the underlying distribution).

  Using the `Uniform` distribution as an example:

  ```python
  minval = 3.0
  maxval = [[4.0, 6.0],
            [10.0, 12.0]]

  # Broadcasting:
  # This instance represents 4 Uniform distributions. Each has a lower bound
  # at 3.0 as the `minval` parameter was broadcasted to match `maxval`'s
  # shape.
  u = Uniform(minval, maxval)

  # `event_shape` is `TensorShape([])`.
  event_shape = u.get_event_shape()
  # `event_shape_t` is a `Tensor` which will evaluate to a scalar 1.
  event_shape_t = u.event_shape

  # Sampling returns a sample per distribution. `samples` has shape
  # (5, 2, 2), which is (n,) + batch_shape + event_shape, where n=5,
  # batch_shape=(2, 2), and event_shape=().
  samples = u.sample(5)

  # The broadcasting holds across methods. Here we use `cdf` as an example.
  # The same holds for `log_cdf` and the likelihood functions.

  # `cum_prob_broadcast` has shape (2, 2) as the `value` argument was
  # broadcasted to the shape of the `Uniform` instance.
  cum_prob_broadcast = u.cdf(4.0)

  # `cum_prob_per_dist` has shape (2, 2), one per distribution. No
  # broadcasting occurred.
  cum_prob_per_dist = u.cdf([[4.0, 5.0],
                             [6.0, 7.0]])

  # INVALID as the `value` argument is not broadcastable to the
  # distribution's shape.
  cum_prob_invalid = u.cdf([4.0, 5.0, 6.0])
  ```
  """
  @abc.abstractproperty
  def name(self):
    """Name to prepend to all ops."""
    pass

  @abc.abstractproperty
  def dtype(self):
    """dtype of samples from this distribution."""
    pass

  @abc.abstractmethod
  def event_shape(self, name=None):
    """Shape of a sample from a single distribution as a 1-D int32 `Tensor`.

    Args:
      name: name to give to the op

    Returns:
      `Tensor` `event_shape`
    """
    pass

  @abc.abstractmethod
  def get_event_shape(self):
    """`TensorShape` available at graph construction time.

    Same meaning as `event_shape`. May be only partially defined.
    """
    pass

  @abc.abstractmethod
  def batch_shape(self, name=None):
    """Batch dimensions of this instance as a 1-D int32 `Tensor`.

    The product of the dimensions of the `batch_shape` is the number of
    independent distributions of this kind the instance represents.

    Args:
      name: name to give to the op

    Returns:
      `Tensor` `batch_shape`
    """
    pass

  @abc.abstractmethod
  def get_batch_shape(self):
    """`TensorShape` available at graph construction time.

    Same meaning as `batch_shape`. May be only partially defined.
    """
    pass

  def sample(self, n, seed=None, name=None):
    """Generate `n` samples.

    Args:
      n: scalar. Number of samples to draw from each distribution.
      seed: Python integer seed for RNG
      name: name to give to the op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
          with values of type `self.dtype`.
    """
    raise NotImplementedError("sample not implemented")

  def cdf(self, value, name="cdf"):
    """Cumulative distribution function."""
    value = ops.convert_to_tensor(value)
    with ops.op_scope([value], self.name):
      with ops.name_scope(name):
        return math_ops.exp(self.log_cdf(value))

  def log_cdf(self, value, name="log_cdf"):
    """Log CDF."""
    raise NotImplementedError("log_cdf is not implemented")

  def entropy(self, name=None):
    """Entropy of the distribution in nats."""
    raise NotImplementedError("entropy not implemented")

  @property
  def mean(self):
    raise NotImplementedError("mean not implemented")
class ContinuousDistribution(BaseDistribution):
  """Base class for continuous probability distributions.

  `ContinuousDistribution` defines the API for the likelihood functions `pdf`
  and `log_pdf` of continuous probability distributions.

  Subclasses must override both `pdf` and `log_pdf`, but one of the two can be
  implemented by simply calling the base class's version, which derives it
  from the other.

  See `BaseDistribution` for more information on the API for probability
  distributions.
  """

  @abc.abstractmethod
  def pdf(self, value, name="pdf"):
    """Probability density function."""
    value = ops.convert_to_tensor(value)
    with ops.op_scope([value], self.name):
      with ops.name_scope(name):
        return math_ops.exp(self.log_pdf(value))

  @abc.abstractmethod
  def log_pdf(self, value, name="log_pdf"):
    """Log of the probability density function."""
    value = ops.convert_to_tensor(value)
    with ops.op_scope([value], self.name):
      with ops.name_scope(name):
        return math_ops.log(self.pdf(value))


class DiscreteDistribution(BaseDistribution):
  """Base class for discrete probability distributions.

  `DiscreteDistribution` defines the API for the likelihood functions `pmf`
  and `log_pmf` of discrete probability distributions.

  Subclasses must override both `pmf` and `log_pmf`, but one of the two can be
  implemented by simply calling the base class's version, which derives it
  from the other.

  See `BaseDistribution` for more information on the API for probability
  distributions.
  """

  @abc.abstractmethod
  def pmf(self, value, name="pmf"):
    """Probability mass function."""
    value = ops.convert_to_tensor(value)
    with ops.op_scope([value], self.name):
      with ops.name_scope(name):
        return math_ops.exp(self.log_pmf(value))

  @abc.abstractmethod
  def log_pmf(self, value, name="log_pmf"):
    """Log of the probability mass function."""
    value = ops.convert_to_tensor(value)
    with ops.op_scope([value], self.name):
      with ops.name_scope(name):
        return math_ops.log(self.pmf(value))
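To make the subclassing contract concrete, here is a minimal sketch of a continuous distribution built on these base classes: a hypothetical `Exponential`, not part of this commit, which overrides `log_pdf` and delegates `pdf` to the base class implementation (the `x < 0` support boundary is ignored for brevity):

```python
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import math_ops


class Exponential(ContinuousDistribution):
  """Hypothetical sketch: Exponential(lam), pdf = lam * exp(-lam * x)."""

  def __init__(self, lam, name="Exponential"):
    self._lam = ops.convert_to_tensor(lam, name="lam")
    self._name = name

  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self._lam.dtype

  def batch_shape(self, name="batch_shape"):
    with ops.name_scope(self.name):
      return array_ops.shape(self._lam, name=name)

  def get_batch_shape(self):
    return self._lam.get_shape()

  def event_shape(self, name="event_shape"):
    with ops.name_scope(self.name):
      return constant_op.constant(1, name=name)

  def get_event_shape(self):
    return tensor_shape.TensorShape([])

  def pdf(self, value, name="pdf"):
    # Delegates to the base class, which computes exp(log_pdf(value)).
    return super(Exponential, self).pdf(value, name)

  def log_pdf(self, value, name="log_pdf"):
    with ops.op_scope([self._lam, value], self.name):
      with ops.name_scope(name):
        value = ops.convert_to_tensor(value, name="value")
        return math_ops.log(self._lam) - self._lam * value
```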

tensorflow/contrib/distributions/python/ops/uniform.py

@@ -0,0 +1,240 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The Uniform distribution class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution # pylint: disable=line-too-long
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util # pylint: disable=line-too-long
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
class Uniform(ContinuousDistribution):
  """Uniform distribution with `a` and `b` parameters.

  The PDF of this distribution is constant on the interval [`a`, `b`], and 0
  elsewhere.
  """

  def __init__(self, a=0.0, b=1.0, name="Uniform"):
    """Construct Uniform distributions with `a` and `b`.

    The parameters `a` and `b` must be shaped in a way that supports
    broadcasting (e.g. `b - a` is a valid operation).

    Here are examples without broadcasting:

    ```python
    # Without broadcasting
    u1 = Uniform(3.0, 4.0)  # a single uniform distribution [3, 4]
    u2 = Uniform([1.0, 2.0], [3.0, 4.0])  # 2 distributions [1, 3], [2, 4]
    u3 = Uniform([[1.0, 2.0],
                  [3.0, 4.0]],
                 [[1.5, 2.5],
                  [3.5, 4.5]])  # 4 distributions
    ```

    And with broadcasting:

    ```python
    u1 = Uniform(3.0, [5.0, 6.0, 7.0])  # 3 distributions
    ```

    Args:
      a: `float` or `double` tensor, the minimum endpoint.
      b: `float` or `double` tensor, the maximum endpoint. Must be > `a`.
      name: The name to prefix Ops created by this distribution class.

    Raises:
      InvalidArgumentError: if `a >= b`.
    """
    with ops.op_scope([a, b], name):
      with ops.control_dependencies([check_ops.assert_less(a, b)]):
        a = ops.convert_to_tensor(a, name="a")
        b = ops.convert_to_tensor(b, name="b")
        if a.dtype != b.dtype:
          raise TypeError("dtypes of a and b do not match: %s vs. %s" %
                          (a.dtype, b.dtype))

    self._a = a
    self._b = b
    self._name = name
    self._batch_shape = self._ones().get_shape()
    self._event_shape = tensor_shape.TensorShape([])

    contrib_tensor_util.assert_same_float_dtype((a, b))
  @property
  def name(self):
    return self._name

  @property
  def dtype(self):
    return self.a.dtype

  def batch_shape(self, name="batch_shape"):
    with ops.name_scope(self.name):
      return array_ops.shape(self._ones(), name=name)

  def get_batch_shape(self):
    return self._batch_shape

  def event_shape(self, name="event_shape"):
    with ops.name_scope(self.name):
      return constant_op.constant(1, name=name)

  def get_event_shape(self):
    return self._event_shape

  @property
  def a(self):
    return self._a

  @property
  def b(self):
    return self._b
  def pdf(self, x, name="pdf"):
    """The PDF of observations in `x` under these Uniform distribution(s).

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `a` and `b`.
      name: The name to give this op.

    Returns:
      pdf: tensor of dtype `dtype`, the pdf values of `x`. If `x` is `nan`,
        will return `nan`.
    """
    with ops.op_scope([self.a, self.b, x], self.name):
      with ops.name_scope(name):
        x = ops.convert_to_tensor(x, name="x")
        if x.dtype != self.dtype:
          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
                          (x.dtype, self.dtype))
        broadcasted_x = x * self._ones()
        return math_ops.select(
            math_ops.is_nan(broadcasted_x), broadcasted_x, math_ops.select(
                math_ops.logical_or(broadcasted_x < self.a,
                                    broadcasted_x > self.b),
                array_ops.zeros_like(broadcasted_x),
                (1.0 / self.range) * array_ops.ones_like(broadcasted_x)))

  def log_pdf(self, x, name="log_pdf"):
    return super(Uniform, self).log_pdf(x, name)

  def cdf(self, x, name="cdf"):
    """CDF of observations in `x` under these Uniform distribution(s).

    Args:
      x: tensor of dtype `dtype`, must be broadcastable with `a` and `b`.
      name: The name to give this op.

    Returns:
      cdf: tensor of dtype `dtype`, the CDFs of `x`. If `x` is `nan`, will
        return `nan`.
    """
    with ops.op_scope([self.a, self.b, x], self.name):
      with ops.name_scope(name):
        x = ops.convert_to_tensor(x, name="x")
        if x.dtype != self.dtype:
          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
                          (x.dtype, self.dtype))
        broadcasted_x = x * self._ones()
        return math_ops.select(
            broadcasted_x < self.a,
            array_ops.zeros_like(broadcasted_x),
            math_ops.select(broadcasted_x >= self.b,
                            array_ops.ones_like(broadcasted_x),
                            (broadcasted_x - self.a) / self.range))

  def log_cdf(self, x, name="log_cdf"):
    with ops.op_scope([self.a, self.b, x], self.name):
      with ops.name_scope(name):
        x = ops.convert_to_tensor(x, name="x")
        return math_ops.log(self.cdf(x))
  def entropy(self, name="entropy"):
    """The entropy of Uniform distribution(s).

    Args:
      name: The name to give this op.

    Returns:
      entropy: tensor of dtype `dtype`, the entropy.
    """
    with ops.op_scope([self.a, self.b], self.name):
      with ops.name_scope(name):
        return math_ops.log(self.range)

  def sample(self, n, seed=None, name="sample"):
    """Sample `n` observations from the Uniform Distributions.

    Args:
      n: `Scalar`, type int32, the number of observations to sample.
      seed: Python integer, the random seed.
      name: The name to give this op.

    Returns:
      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
          with values of type `self.dtype`.
    """
    with ops.op_scope([self.a, self.b, n], self.name):
      with ops.name_scope(name):
        n = ops.convert_to_tensor(n, name="n")
        n_val = tensor_util.constant_value(n)

        shape = array_ops.concat(0, [array_ops.pack([n]), self.batch_shape()])
        samples = random_ops.random_uniform(shape=shape,
                                            dtype=self.dtype,
                                            seed=seed)

        # Provide some hints to shape inference.
        inferred_shape = tensor_shape.vector(n_val).concatenate(
            self.get_batch_shape())
        samples.set_shape(inferred_shape)

        return (array_ops.expand_dims(self.a, 0) + array_ops.expand_dims(
            self.range, 0) * samples)

  @property
  def mean(self):
    return (self.a + self.b) / 2

  @property
  def variance(self):
    return math_ops.square(self.range) / 12

  @property
  def range(self):
    """`b - a`."""
    return self.b - self.a

  # TODO(rsepassi): Find a more efficient way of doing the broadcasting in
  # _ones and _zeros.
  def _ones(self):
    return array_ops.ones_like(self.a + self.b)

  def _zeros(self):
    return array_ops.zeros_like(self.a + self.b)
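End to end, the new class composes with a plain session run. A minimal usage sketch, assuming the contrib-era API added in this commit:

```python
import tensorflow as tf

# One scalar `a` broadcast against two `b`s: a batch of 2 distributions.
uniform = tf.contrib.distributions.Uniform(a=0.0, b=[2.0, 4.0])

with tf.Session() as sess:
  pdf, cdf, samples = sess.run([
      uniform.pdf([1.0, 1.0]),  # [0.5, 0.25] -- 1 / (b - a) inside [a, b]
      uniform.cdf([1.0, 1.0]),  # [0.5, 0.25] -- (x - a) / (b - a)
      uniform.sample(1000, seed=42),
  ])
  print(pdf, cdf, samples.shape)  # samples.shape == (1000, 2)
```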