From 34b49bb5b91d5e3902690597b5264710a39a9c99 Mon Sep 17 00:00:00 2001 From: Richard Uhler Date: Wed, 9 Dec 2020 16:31:54 -0800 Subject: [PATCH] Add lowering from tf.RiscAdd op to HLO_Add. PiperOrigin-RevId: 346663526 Change-Id: If21a1a88d27bb8c16e5a1a2331fc436646a94ee6 --- .../mlir/tensorflow/ir/tf_generated_ops.td | 23 ++++++++++ .../xla/transforms/legalize_tf_patterns.td | 2 + tensorflow/compiler/tests/BUILD | 16 +++++++ tensorflow/compiler/tests/risc_ops_test.py | 44 +++++++++++++++++++ tensorflow/python/ops/risc/risc_ops.py | 5 --- 5 files changed, 85 insertions(+), 5 deletions(-) create mode 100644 tensorflow/compiler/tests/risc_ops_test.py diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 90ff30c6653..81007dca35e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -11645,6 +11645,29 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_RiscAddOp : TF_Op<"RiscAdd", [Commutative, NoSideEffect]> { + let summary = "Returns x + y element-wise."; + + let description = [{ +*NOTE*: `RiscAdd` does not supports broadcasting. + +Given two input tensors, the `tf.risc_add` operation computes the sum for every element in the tensor. + +Both input and output have a range `(-inf, inf)`. + }]; + + let arguments = (ins + TF_FloatTensor:$x, + TF_FloatTensor:$y + ); + + let results = (outs + TF_FloatTensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> { let summary = "Rolls the elements of a tensor along an axis."; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index f897351e6c2..f7d1f6641db 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -188,6 +188,8 @@ def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), (HLOClient_BroadcastAddOp $r, $rem, (BinBroadcastDimensions $r, $rem)), $rem)>; +def : Pat<(TF_RiscAddOp $l, $r), (HLO_AddOp $l, $r)>; + //===----------------------------------------------------------------------===// // Logical & bitwise binary op patterns. //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index f0941d3621b..8472fdd9da1 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1965,3 +1965,19 @@ tf_xla_py_test( "//tensorflow/python/compiler/xla:compiler_py", ], ) + +tf_xla_py_test( + name = "risc_ops_test", + size = "small", + srcs = ["risc_ops_test.py"], + enabled_backends = ["cpu"], + python_version = "PY3", + deps = [ + ":xla_test", + "//tensorflow/python:framework", + "//tensorflow/python:is_mlir_bridge_test_true", + "//tensorflow/python:platform_test", + "//tensorflow/python/eager:function", + "//tensorflow/python/ops/risc:risc_ops", + ], +) diff --git a/tensorflow/compiler/tests/risc_ops_test.py b/tensorflow/compiler/tests/risc_ops_test.py new file mode 100644 index 00000000000..0aa0936bd0f --- /dev/null +++ b/tensorflow/compiler/tests/risc_ops_test.py @@ -0,0 +1,44 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for RISC Ops.""" + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.eager import def_function +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops.risc import risc_ops +from tensorflow.python.platform import test + + +class XlaRiscOpsTest(xla_test.XLATestCase): + + def testRiscAddBasic(self): + + @def_function.function(jit_compile=True) + def f(a, b): + return risc_ops.risc_add(a, b) + + l1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], + dtype=dtypes.float32) + l2 = constant_op.constant([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], + dtype=dtypes.float32) + l = f(l1, l2) + self.assertAllEqual(l, [[8.0, 10.0], [12.0, 14.0], [16.0, 18.0]]) + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/python/ops/risc/risc_ops.py b/tensorflow/python/ops/risc/risc_ops.py index 14cdb9e4b44..05a609402ff 100644 --- a/tensorflow/python/ops/risc/risc_ops.py +++ b/tensorflow/python/ops/risc/risc_ops.py @@ -21,11 +21,6 @@ from __future__ import print_function from tensorflow.python.ops import gen_risc_ops -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.python.ops.risc_ops_gen import * -# pylint: enable=wildcard-import - def risc_add( input_lhs,