From aa2cacd6627ffb296bedc910c957a0fd4a2f957f Mon Sep 17 00:00:00 2001 From: Siddharth Agrawal Date: Fri, 1 Jul 2016 01:59:14 +0530 Subject: [PATCH] Enable tf.erf() for SparseTensor (#3122) --- .../python/kernel_tests/cwise_ops_test.py | 4 ++++ tensorflow/python/ops/math_ops.py | 19 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 7f1be574bbf..093da97469a 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -215,6 +215,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) + self._compareBothSparse(x, np.vectorize(math.erf), tf.erf) def testFloatTanhEdge(self): x = np.arange(40, 40 + 6).reshape(6).astype(np.float32) @@ -254,6 +255,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3) self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(x, np.sign, tf.sign) + self._compareBothSparse(x, np.sign, tf.erf) def testDoubleBasic(self): x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64) @@ -292,6 +294,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) + self._compareBothSparse(x, np.vectorize(math.erf), tf.erf) def testHalfBasic(self): x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16) @@ -325,6 +328,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) + self._compareBothSparse(x, np.vectorize(math.erf), tf.erf, tol=1e-3) def testInt32Basic(self): x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32) diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 0a76450c5b1..38c7e515941 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -348,6 +348,25 @@ def sqrt(x, name=None): return gen_math_ops.sqrt(x, name=name) +def erf(x, name=None): + """Computes the Gauss error function of `x` element-wise. + + Args: + x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`, + `float32`, `float64`. + name: A name for the operation (optional). + + Returns: + A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`. + """ + with ops.op_scope([x], name, "Erf") as name: + if isinstance(x, ops.SparseTensor): + x_erf = gen_math_ops.erf(x.values, name=name) + return ops.SparseTensor(indices=x.indices, values=x_erf, shape=x.shape) + else: + return gen_math_ops.erf(x, name=name) + + def complex_abs(x, name=None): r"""Computes the complex absolute value of a tensor.