Enable tf.erf() for SparseTensor (#3122)
This commit is contained in:
parent
ac90ecb08d
commit
aa2cacd662
@ -215,6 +215,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
|
||||
self._compareBothSparse(x, np.tanh, tf.tanh)
|
||||
self._compareBothSparse(y, np.sign, tf.sign)
|
||||
self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
|
||||
|
||||
def testFloatTanhEdge(self):
|
||||
x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
|
||||
@ -254,6 +255,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3)
|
||||
self._compareBothSparse(x, np.tanh, tf.tanh)
|
||||
self._compareBothSparse(x, np.sign, tf.sign)
|
||||
self._compareBothSparse(x, np.sign, tf.erf)
|
||||
|
||||
def testDoubleBasic(self):
|
||||
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
|
||||
@ -292,6 +294,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
|
||||
self._compareBothSparse(x, np.tanh, tf.tanh)
|
||||
self._compareBothSparse(y, np.sign, tf.sign)
|
||||
self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
|
||||
|
||||
def testHalfBasic(self):
|
||||
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
|
||||
@ -325,6 +328,7 @@ class UnaryOpTest(tf.test.TestCase):
|
||||
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
|
||||
self._compareBothSparse(x, np.tanh, tf.tanh)
|
||||
self._compareBothSparse(y, np.sign, tf.sign)
|
||||
self._compareBothSparse(x, np.vectorize(math.erf), tf.erf, tol=1e-3)
|
||||
|
||||
def testInt32Basic(self):
|
||||
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
|
||||
|
@ -348,6 +348,25 @@ def sqrt(x, name=None):
|
||||
return gen_math_ops.sqrt(x, name=name)
|
||||
|
||||
|
||||
def erf(x, name=None):
|
||||
"""Computes the Gauss error function of `x` element-wise.
|
||||
|
||||
Args:
|
||||
x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`,
|
||||
`float32`, `float64`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
|
||||
"""
|
||||
with ops.op_scope([x], name, "Erf") as name:
|
||||
if isinstance(x, ops.SparseTensor):
|
||||
x_erf = gen_math_ops.erf(x.values, name=name)
|
||||
return ops.SparseTensor(indices=x.indices, values=x_erf, shape=x.shape)
|
||||
else:
|
||||
return gen_math_ops.erf(x, name=name)
|
||||
|
||||
|
||||
def complex_abs(x, name=None):
|
||||
r"""Computes the complex absolute value of a tensor.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user