Enable tf.erf() for SparseTensor (#3122)

This commit is contained in:
Siddharth Agrawal 2016-07-01 01:59:14 +05:30 committed by Martin Wicke
parent ac90ecb08d
commit aa2cacd662
2 changed files with 23 additions and 0 deletions

View File

@ -215,6 +215,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
def testFloatTanhEdge(self):
x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
@ -254,6 +255,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(x, np.sign, tf.sign)
self._compareBothSparse(x, np.sign, tf.erf)
def testDoubleBasic(self):
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
@ -292,6 +294,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
def testHalfBasic(self):
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
@ -325,6 +328,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
self._compareBothSparse(x, np.vectorize(math.erf), tf.erf, tol=1e-3)
def testInt32Basic(self):
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)

View File

@ -348,6 +348,25 @@ def sqrt(x, name=None):
return gen_math_ops.sqrt(x, name=name)
def erf(x, name=None):
"""Computes the Gauss error function of `x` element-wise.
Args:
x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`,
`float32`, `float64`.
name: A name for the operation (optional).
Returns:
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
"""
with ops.op_scope([x], name, "Erf") as name:
if isinstance(x, ops.SparseTensor):
x_erf = gen_math_ops.erf(x.values, name=name)
return ops.SparseTensor(indices=x.indices, values=x_erf, shape=x.shape)
else:
return gen_math_ops.erf(x, name=name)
def complex_abs(x, name=None):
r"""Computes the complex absolute value of a tensor.