Add int32 support to floor, ceil & rint
Though the result is trivial, this avoids the need to call tf.cast if receiving a int32 tensor from another operation. PiperOrigin-RevId: 283883439 Change-Id: I351206bd165fbf681f0231886ec131522ddf83ed
This commit is contained in:
parent
6f96b26d9d
commit
5666233dc7
@ -16,8 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER4(UnaryOp, CPU, "Ceil", functor::ceil, float, Eigen::half, double,
|
||||
int32);
|
||||
REGISTER3(UnaryOp, CPU, "Ceil", functor::ceil, float, Eigen::half, double);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER3(UnaryOp, GPU, "Ceil", functor::ceil, float, Eigen::half, double);
|
||||
|
@ -16,8 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER4(UnaryOp, CPU, "Floor", functor::floor, float, Eigen::half, double,
|
||||
int32);
|
||||
REGISTER3(UnaryOp, CPU, "Floor", functor::floor, float, Eigen::half, double);
|
||||
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER3(UnaryOp, GPU, "Floor", functor::floor, float, Eigen::half, double);
|
||||
|
@ -16,7 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
REGISTER3(UnaryOp, CPU, "Rint", functor::rint, float, double, int32);
|
||||
REGISTER2(UnaryOp, CPU, "Rint", functor::rint, float, double);
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER2(UnaryOp, GPU, "Rint", functor::rint, float, double);
|
||||
#endif
|
||||
|
@ -349,19 +349,19 @@ REGISTER_OP("Sign")
|
||||
REGISTER_OP("Floor")
|
||||
.Input("x: T")
|
||||
.Output("y: T")
|
||||
.Attr("T: {bfloat16, half, float, double, int32}")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("Ceil")
|
||||
.Input("x: T")
|
||||
.Output("y: T")
|
||||
.Attr("T: {bfloat16, half, float, double, int32}")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("Rint")
|
||||
.Input("x: T")
|
||||
.Output("y: T")
|
||||
.Attr("T: {bfloat16, half, float, double, int32}")
|
||||
.Attr("T: {bfloat16, half, float, double}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
// Declares cwise binary operations signature: 't, 't -> 't.
|
||||
|
@ -179,48 +179,6 @@ class RoundTest(test_util.TensorFlowTestCase):
|
||||
self.assertAllClose(y_tf_np, y_np, atol=1e-2)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class FloorTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testFloor(self):
|
||||
x = np.arange(-5.0, 5.0, .25)
|
||||
for dtype in [np.float32, np.double, np.int32]:
|
||||
x_np = np.array(x, dtype=dtype)
|
||||
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
||||
y_tf = math_ops.floor(x_tf)
|
||||
y_tf_np = self.evaluate(y_tf)
|
||||
y_np = np.floor(x_np)
|
||||
self.assertAllClose(y_tf_np, y_np, atol=1e-2)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class CeilTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testCeil(self):
|
||||
x = np.arange(-5.0, 5.0, .25)
|
||||
for dtype in [np.float32, np.double, np.int32]:
|
||||
x_np = np.array(x, dtype=dtype)
|
||||
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
||||
y_tf = math_ops.ceil(x_tf)
|
||||
y_tf_np = self.evaluate(y_tf)
|
||||
y_np = np.ceil(x_np)
|
||||
self.assertAllClose(y_tf_np, y_np, atol=1e-2)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RintTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testRint(self):
|
||||
x = np.arange(-5.0, 5.0, .25)
|
||||
for dtype in [np.float32, np.double, np.int32]:
|
||||
x_np = np.array(x, dtype=dtype)
|
||||
x_tf = constant_op.constant(x_np, shape=x_np.shape)
|
||||
y_tf = math_ops.rint(x_tf)
|
||||
y_tf_np = self.evaluate(y_tf)
|
||||
y_np = np.rint(x_np)
|
||||
self.assertAllClose(y_tf_np, y_np, atol=1e-2)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class ModTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user