Add lowering from tf.RiscAdd op to HLO_Add.
PiperOrigin-RevId: 346663526 Change-Id: If21a1a88d27bb8c16e5a1a2331fc436646a94ee6
This commit is contained in:
parent
b48df5c1f3
commit
34b49bb5b9
@ -11645,6 +11645,29 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.]
|
|||||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TF_RiscAddOp : TF_Op<"RiscAdd", [Commutative, NoSideEffect]> {
|
||||||
|
let summary = "Returns x + y element-wise.";
|
||||||
|
|
||||||
|
let description = [{
|
||||||
|
*NOTE*: `RiscAdd` does not supports broadcasting.
|
||||||
|
|
||||||
|
Given two input tensors, the `tf.risc_add` operation computes the sum for every element in the tensor.
|
||||||
|
|
||||||
|
Both input and output have a range `(-inf, inf)`.
|
||||||
|
}];
|
||||||
|
|
||||||
|
let arguments = (ins
|
||||||
|
TF_FloatTensor:$x,
|
||||||
|
TF_FloatTensor:$y
|
||||||
|
);
|
||||||
|
|
||||||
|
let results = (outs
|
||||||
|
TF_FloatTensor:$z
|
||||||
|
);
|
||||||
|
|
||||||
|
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||||
|
}
|
||||||
|
|
||||||
def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
|
def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> {
|
||||||
let summary = "Rolls the elements of a tensor along an axis.";
|
let summary = "Rolls the elements of a tensor along an axis.";
|
||||||
|
|
||||||
|
@ -188,6 +188,8 @@ def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r),
|
|||||||
(HLOClient_BroadcastAddOp $r,
|
(HLOClient_BroadcastAddOp $r,
|
||||||
$rem, (BinBroadcastDimensions $r, $rem)), $rem)>;
|
$rem, (BinBroadcastDimensions $r, $rem)), $rem)>;
|
||||||
|
|
||||||
|
def : Pat<(TF_RiscAddOp $l, $r), (HLO_AddOp $l, $r)>;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Logical & bitwise binary op patterns.
|
// Logical & bitwise binary op patterns.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1965,3 +1965,19 @@ tf_xla_py_test(
|
|||||||
"//tensorflow/python/compiler/xla:compiler_py",
|
"//tensorflow/python/compiler/xla:compiler_py",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_xla_py_test(
|
||||||
|
name = "risc_ops_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["risc_ops_test.py"],
|
||||||
|
enabled_backends = ["cpu"],
|
||||||
|
python_version = "PY3",
|
||||||
|
deps = [
|
||||||
|
":xla_test",
|
||||||
|
"//tensorflow/python:framework",
|
||||||
|
"//tensorflow/python:is_mlir_bridge_test_true",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
"//tensorflow/python/eager:function",
|
||||||
|
"//tensorflow/python/ops/risc:risc_ops",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
44
tensorflow/compiler/tests/risc_ops_test.py
Normal file
44
tensorflow/compiler/tests/risc_ops_test.py
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for RISC Ops."""
|
||||||
|
|
||||||
|
from tensorflow.compiler.tests import xla_test
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops.risc import risc_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class XlaRiscOpsTest(xla_test.XLATestCase):
|
||||||
|
|
||||||
|
def testRiscAddBasic(self):
|
||||||
|
|
||||||
|
@def_function.function(jit_compile=True)
|
||||||
|
def f(a, b):
|
||||||
|
return risc_ops.risc_add(a, b)
|
||||||
|
|
||||||
|
l1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
||||||
|
dtype=dtypes.float32)
|
||||||
|
l2 = constant_op.constant([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
|
||||||
|
dtype=dtypes.float32)
|
||||||
|
l = f(l1, l2)
|
||||||
|
self.assertAllEqual(l, [[8.0, 10.0], [12.0, 14.0], [16.0, 18.0]])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ops.enable_eager_execution()
|
||||||
|
test.main()
|
@ -21,11 +21,6 @@ from __future__ import print_function
|
|||||||
|
|
||||||
from tensorflow.python.ops import gen_risc_ops
|
from tensorflow.python.ops import gen_risc_ops
|
||||||
|
|
||||||
# go/tf-wildcard-import
|
|
||||||
# pylint: disable=wildcard-import
|
|
||||||
from tensorflow.python.ops.risc_ops_gen import *
|
|
||||||
# pylint: enable=wildcard-import
|
|
||||||
|
|
||||||
|
|
||||||
def risc_add(
|
def risc_add(
|
||||||
input_lhs,
|
input_lhs,
|
||||||
|
Loading…
Reference in New Issue
Block a user