Enable MLIR bridge for ops that are already supported

* Removed SameOperandsAndResultShape from the ClipByValue op, since the op has multiple operands and the UnchangedShape function does not imply that all operand shapes are equal.
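For illustration only (not part of this commit): ClipByValue's clip bounds may be scalars while the input is a higher-rank tensor, which is why a trait requiring all operands to match the result does not hold. A minimal sketch using the public `tf.clip_by_value` API, which lowers to the ClipByValue op:

```python
# Illustrative sketch (not from this commit): ClipByValue's operands need not
# all share one shape -- the clip bounds may be scalars while the input is a
# rank-2 tensor, so a "same operands and result shape/type" trait is too
# strong for this op.
import numpy as np
import tensorflow as tf

t = tf.constant(np.arange(6, dtype=np.float32).reshape(2, 3))
# Scalar bounds; the result keeps the shape of `t`.
clipped = tf.clip_by_value(t, clip_value_min=1.0, clip_value_max=4.0)
print(clipped.numpy())  # [[1. 1. 2.] [3. 4. 4.]]
```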

PiperOrigin-RevId: 310465081
Change-Id: Id56c0d67e1b70638c57df5f4d1e5da1875529064
Authored by Smit Hinsu on 2020-05-07 17:04:56 -07:00; committed by TensorFlower Gardener
parent 491d6e42ce
commit 4d94fe13fa
7 changed files with 34 additions and 3 deletions

@@ -1217,7 +1217,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
-def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
+def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> {
let summary = "Clips tensor values to a specified min and max.";
let description = [{
@@ -1682,6 +1682,27 @@ Given an input tensor, this function computes hyperbolic cosine of every
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
+def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> {
+let summary = "Compute the pairwise cross product.";
+let description = [{
+`a` and `b` must be the same shape; they can either be simple 3-element vectors,
+or any shape where the innermost dimension is 3. In the latter case, each pair
+of corresponding 3-element vectors is cross-multiplied independently.
+}];
+let arguments = (ins
+TF_IntOrFpTensor:$a,
+TF_IntOrFpTensor:$b
+);
+let results = (outs
+TF_IntOrFpTensor:$product
+);
+TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
+}
def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> {
let summary = "An Op to sum inputs across replicated TPU instances.";
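As an illustrative sketch (not part of the diff), the behavior described in the Cross op definition above can be exercised through `tf.linalg.cross`, which maps to the Cross op:

```python
# Sketch (illustrative): `a` and `b` share a shape whose innermost dimension
# is 3; each pair of 3-element vectors is cross-multiplied independently.
import tensorflow as tf

a = tf.constant([[1., 0., 0.], [0., 1., 0.]])
b = tf.constant([[0., 1., 0.], [0., 0., 1.]])
print(tf.linalg.cross(a, b).numpy())  # [[0. 0. 1.] [1. 0. 0.]]
```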

@@ -100,8 +100,10 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::BitwiseOrOp>(),
TypeID::get<TF::BitwiseXorOp>(),
TypeID::get<TF::CastOp>(),
+TypeID::get<TF::ClipByValueOp>(),
TypeID::get<TF::ComplexAbsOp>(),
TypeID::get<TF::CoshOp>(),
+TypeID::get<TF::CrossOp>(),
TypeID::get<TF::DataFormatDimMapOp>(),
TypeID::get<TF::DataFormatVecPermuteOp>(),
TypeID::get<TF::DigammaOp>(),
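This whitelist controls which TF ops the MLIR bridge may legalize via the existing tf2xla fallback kernels, which is how the ClipByValue and Cross ops become usable under the bridge. A hedged sketch of exercising one of these ops under XLA compilation (API names are from the public TF surface and not from this change; exact bridge behavior depends on the installed TensorFlow build):

```python
# Hedged sketch: compile a whitelisted op with XLA. Enabling the MLIR bridge
# here uses tf.config.experimental.enable_mlir_bridge(); whether a given op is
# actually routed through the bridge depends on the TensorFlow build.
import tensorflow as tf

tf.config.experimental.enable_mlir_bridge()

@tf.function(jit_compile=True)  # older releases name this experimental_compile
def clip(x):
  return tf.clip_by_value(x, 0.0, 1.0)

print(clip(tf.constant([-0.5, 0.5, 1.5])).numpy())  # [0.  0.5 1. ]
```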

@@ -470,6 +470,7 @@ tf_xla_py_test(
name = "concat_ops_test",
size = "medium",
srcs = ["concat_ops_test.py"],
+enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"many_xla_args",
@@ -1342,6 +1343,7 @@ tf_xla_py_test(
name = "ternary_ops_test",
size = "medium",
srcs = ["ternary_ops_test.py"],
+enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

@@ -1511,7 +1511,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([1, 0], dtype=np.int32),
expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype))
@test_util.disable_mlir_bridge("Enable tf.Cross Compilation")
def testCross(self):
for dtype in self.float_types:
self._testBinary(

@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gradients_impl
@@ -293,6 +294,7 @@ class ConcatTest(xla_test.XLATestCase):
# The purpose of this is to ensure that XLA on GPU will not run out of memory
# with too many arguments.
@test_util.disable_mlir_bridge("TODO(b/153895138): Debug.")
def testConcatLargeNumberOfTensors(self):
if "CPU" in self.device:
self.skipTest("This test can time out on CPU, so we will just allow "

@@ -24,6 +24,7 @@ import scipy.special as sps
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
@@ -47,6 +48,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
{'start': 1, 'end': 2, 'num': 1},
{'start': 1, 'end': 4, 'num': 3},
{'start': 0, 'end': 41, 'num': 42})
+@test_util.disable_mlir_bridge('Requires dynamic shape handling')
def testLinspace(self, start, end, num):
expected = np.linspace(start, end, num, dtype=np.float32)
result = self._testTernary(
@@ -74,6 +76,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
np.int32(2),
expected=np.array([1, 3, 5], dtype=np.int32))
+@test_util.disable_mlir_bridge('TODO(b/155949336)')
def testSelect(self):
for dtype in self.numeric_types:
self._testTernary(
@@ -179,6 +182,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
np.array([8, 9], dtype=dtype),
expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
+@test_util.disable_mlir_bridge('TODO(b/155097273)')
def testSlice(self):
for dtype in self.numeric_types:
self._testTernary(
@@ -211,6 +215,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
upper,
expected=np.minimum(np.maximum(x, lower), upper))
+@test_util.disable_mlir_bridge('Enable tf.Betainc Compilation')
def testBetaincSanity(self):
# This operation is only supported for float32 and float64.
for dtype in self.numeric_types & {np.float32, np.float64}:
@@ -248,6 +253,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
'atol': 2e-4
},
)
+@test_util.disable_mlir_bridge('Enable tf.Betainc Compilation')
def testBetainc(self, sigma, rtol, atol):
# This operation is only supported for float32 and float64.
for dtype in self.numeric_types & {np.float32, np.float64}:

@@ -72,7 +72,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
np.array([7, 11], dtype=dtype)),
expected=np.array([[8, 13], [10, 15]], dtype=dtype))
-@test_util.disable_mlir_bridge('Not supported yet')
def testBroadcast(self):
for dtype in self.numeric_types:
v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])