Enable MLIR bridge for ops that are already supported
* Removed SameOperandsAndResultShape from ClipByValue op as the op has multiple operands and UnchangedShape function doesn't mean all shapes are equal. PiperOrigin-RevId: 310465081 Change-Id: Id56c0d67e1b70638c57df5f4d1e5da1875529064
This commit is contained in:
parent
491d6e42ce
commit
4d94fe13fa
@ -1217,7 +1217,7 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
|
||||
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect]> {
|
||||
let summary = "Clips tensor values to a specified min and max.";
|
||||
|
||||
let description = [{
|
||||
@ -1682,6 +1682,27 @@ Given an input tensor, this function computes hyperbolic cosine of every
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CrossOp : TF_Op<"Cross", [NoSideEffect]> {
|
||||
let summary = "Compute the pairwise cross product.";
|
||||
|
||||
let description = [{
|
||||
`a` and `b` must be the same shape; they can either be simple 3-element vectors,
|
||||
or any shape where the innermost dimension is 3. In the latter case, each pair
|
||||
of corresponding 3-element vectors is cross-multiplied independently.
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_IntOrFpTensor:$a,
|
||||
TF_IntOrFpTensor:$b
|
||||
);
|
||||
|
||||
let results = (outs
|
||||
TF_IntOrFpTensor:$product
|
||||
);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
}
|
||||
|
||||
def TF_CrossReplicaSumOp : TF_Op<"CrossReplicaSum", [AllTypesMatch<["input", "output"]>, NoSideEffect]> {
|
||||
let summary = "An Op to sum inputs across replicated TPU instances.";
|
||||
|
||||
|
@ -100,8 +100,10 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||
TypeID::get<TF::BitwiseOrOp>(),
|
||||
TypeID::get<TF::BitwiseXorOp>(),
|
||||
TypeID::get<TF::CastOp>(),
|
||||
TypeID::get<TF::ClipByValueOp>(),
|
||||
TypeID::get<TF::ComplexAbsOp>(),
|
||||
TypeID::get<TF::CoshOp>(),
|
||||
TypeID::get<TF::CrossOp>(),
|
||||
TypeID::get<TF::DataFormatDimMapOp>(),
|
||||
TypeID::get<TF::DataFormatVecPermuteOp>(),
|
||||
TypeID::get<TF::DigammaOp>(),
|
||||
|
@ -470,6 +470,7 @@ tf_xla_py_test(
|
||||
name = "concat_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["concat_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"many_xla_args",
|
||||
@ -1342,6 +1343,7 @@ tf_xla_py_test(
|
||||
name = "ternary_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["ternary_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
|
@ -1511,7 +1511,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([1, 0], dtype=np.int32),
|
||||
expected=np.array([[1 + 1j, 3 + 3j], [2 - 2j, 4 - 4j]], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge("Enable tf.Cross Compilation")
|
||||
def testCross(self):
|
||||
for dtype in self.float_types:
|
||||
self._testBinary(
|
||||
|
@ -23,6 +23,7 @@ import numpy as np
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
@ -293,6 +294,7 @@ class ConcatTest(xla_test.XLATestCase):
|
||||
|
||||
# The purpose of this is to ensure that XLA on GPU will not run out of memory
|
||||
# with too many arguments.
|
||||
@test_util.disable_mlir_bridge("TODO(b/153895138): Debug.")
|
||||
def testConcatLargeNumberOfTensors(self):
|
||||
if "CPU" in self.device:
|
||||
self.skipTest("This test can time out on CPU, so we will just allow "
|
||||
|
@ -24,6 +24,7 @@ import scipy.special as sps
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -47,6 +48,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
{'start': 1, 'end': 2, 'num': 1},
|
||||
{'start': 1, 'end': 4, 'num': 3},
|
||||
{'start': 0, 'end': 41, 'num': 42})
|
||||
@test_util.disable_mlir_bridge('Requires dynamic shape handling')
|
||||
def testLinspace(self, start, end, num):
|
||||
expected = np.linspace(start, end, num, dtype=np.float32)
|
||||
result = self._testTernary(
|
||||
@ -74,6 +76,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
np.int32(2),
|
||||
expected=np.array([1, 3, 5], dtype=np.int32))
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/155949336)')
|
||||
def testSelect(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testTernary(
|
||||
@ -179,6 +182,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
np.array([8, 9], dtype=dtype),
|
||||
expected=np.array([[7, 9], [8, 7], [8, 9]], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/155097273)')
|
||||
def testSlice(self):
|
||||
for dtype in self.numeric_types:
|
||||
self._testTernary(
|
||||
@ -211,6 +215,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
upper,
|
||||
expected=np.minimum(np.maximum(x, lower), upper))
|
||||
|
||||
@test_util.disable_mlir_bridge('Enable tf.Betainc Compilation')
|
||||
def testBetaincSanity(self):
|
||||
# This operation is only supported for float32 and float64.
|
||||
for dtype in self.numeric_types & {np.float32, np.float64}:
|
||||
@ -248,6 +253,7 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
'atol': 2e-4
|
||||
},
|
||||
)
|
||||
@test_util.disable_mlir_bridge('Enable tf.Betainc Compilation')
|
||||
def testBetainc(self, sigma, rtol, atol):
|
||||
# This operation is only supported for float32 and float64.
|
||||
for dtype in self.numeric_types & {np.float32, np.float64}:
|
||||
|
@ -72,7 +72,6 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
np.array([7, 11], dtype=dtype)),
|
||||
expected=np.array([[8, 13], [10, 15]], dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testBroadcast(self):
|
||||
for dtype in self.numeric_types:
|
||||
v = np.arange(4, dtype=np.int32).astype(dtype).reshape([2, 2])
|
||||
|
Loading…
Reference in New Issue
Block a user