Add lowering from tf.RiscAdd op to HLO_Add.

PiperOrigin-RevId: 346663526
Change-Id: If21a1a88d27bb8c16e5a1a2331fc436646a94ee6
This commit is contained in:
Richard Uhler 2020-12-09 16:31:54 -08:00 committed by TensorFlower Gardener
parent b48df5c1f3
commit 34b49bb5b9
5 changed files with 85 additions and 5 deletions

View File

@ -11645,6 +11645,29 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.]
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_RiscAddOp : TF_Op<"RiscAdd", [Commutative, NoSideEffect]> {
let summary = "Returns x + y element-wise.";
let description = [{
*NOTE*: `RiscAdd` does not supports broadcasting.
Given two input tensors, the `tf.risc_add` operation computes the sum for every element in the tensor.
Both input and output have a range `(-inf, inf)`.
}];
let arguments = (ins
TF_FloatTensor:$x,
TF_FloatTensor:$y
);
let results = (outs
TF_FloatTensor:$z
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> { def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
let summary = "Rolls the elements of a tensor along an axis."; let summary = "Rolls the elements of a tensor along an axis.";

View File

@ -188,6 +188,8 @@ def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
(HLOClient_BroadcastAddOp $r, (HLOClient_BroadcastAddOp $r,
$rem, (BinBroadcastDimensions $r, $rem)), $rem)>; $rem, (BinBroadcastDimensions $r, $rem)), $rem)>;
def : Pat<(TF_RiscAddOp $l, $r), (HLO_AddOp $l, $r)>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Logical & bitwise binary op patterns. // Logical & bitwise binary op patterns.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -1965,3 +1965,19 @@ tf_xla_py_test(
"//tensorflow/python/compiler/xla:compiler_py", "//tensorflow/python/compiler/xla:compiler_py",
], ],
) )
tf_xla_py_test(
name = "risc_ops_test",
size = "small",
srcs = ["risc_ops_test.py"],
enabled_backends = ["cpu"],
python_version = "PY3",
deps = [
":xla_test",
"//tensorflow/python:framework",
"//tensorflow/python:is_mlir_bridge_test_true",
"//tensorflow/python:platform_test",
"//tensorflow/python/eager:function",
"//tensorflow/python/ops/risc:risc_ops",
],
)

View File

@ -0,0 +1,44 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RISC Ops."""
from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops.risc import risc_ops
from tensorflow.python.platform import test
class XlaRiscOpsTest(xla_test.XLATestCase):
def testRiscAddBasic(self):
@def_function.function(jit_compile=True)
def f(a, b):
return risc_ops.risc_add(a, b)
l1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
dtype=dtypes.float32)
l2 = constant_op.constant([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
dtype=dtypes.float32)
l = f(l1, l2)
self.assertAllEqual(l, [[8.0, 10.0], [12.0, 14.0], [16.0, 18.0]])
if __name__ == "__main__":
ops.enable_eager_execution()
test.main()

View File

@ -21,11 +21,6 @@ from __future__ import print_function
from tensorflow.python.ops import gen_risc_ops from tensorflow.python.ops import gen_risc_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.risc_ops_gen import *
# pylint: enable=wildcard-import
def risc_add( def risc_add(
input_lhs, input_lhs,