diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index eb6fe2e98b4..d0949a08eaf 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -77,9 +77,13 @@ static bool IsOpWhitelisted(Operation* op) {
   // building valid MLIR using MlirHloBuilder.
   // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
   // all tf2xla kernels.
-  return isa(op) || isa(op) || isa(op) ||
-         isa(op) || isa(op) ||
-         isa(op);
+  return isa(op) || isa(op) ||
+         isa(op) || isa(op) ||
+         isa(op) || isa(op) ||
+         isa(op) || isa(op) || isa(op) ||
+         isa(op) || isa(op) ||
+         isa(op) || isa(op) ||
+         isa(op) || isa(op);
 }
 
 static std::unique_ptr CreateDeviceMgr(
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 1ee25813320..5325addc8df 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1346,26 +1346,6 @@ tf_xla_py_test(
     name = "unary_ops_test",
     size = "medium",
     srcs = ["unary_ops_test.py"],
-    python_version = "PY3",
-    tags = [
-        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
-    ],
-    deps = [
-        ":xla_test",
-        "//tensorflow/python:array_ops",
-        "//tensorflow/python:framework",
-        "//tensorflow/python:math_ops",
-        "//tensorflow/python:nn_ops",
-        "//tensorflow/python:nn_ops_gen",
-        "//tensorflow/python:platform_test",
-    ],
-)
-
-# TODO(hinsu): Combine this test with unary_ops_test instead of replicating it.
-tf_xla_py_test(
-    name = "unary_mlir_ops_test",
-    size = "medium",
-    srcs = ["unary_mlir_ops_test.py"],
     enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
diff --git a/tensorflow/compiler/tests/unary_mlir_ops_test.py b/tensorflow/compiler/tests/unary_mlir_ops_test.py
deleted file mode 100644
index 4238877c761..00000000000
--- a/tensorflow/compiler/tests/unary_mlir_ops_test.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for XLA JIT compiler."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.compiler.tests import xla_test
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.platform import googletest
-
-
-class UnaryOpsTest(xla_test.XLATestCase):
-  """Test cases for unary operators."""
-
-  def _assertOpOutputMatchesExpected(self,
-                                     op,
-                                     inp,
-                                     expected,
-                                     equality_test=None,
-                                     rtol=1e-3,
-                                     atol=1e-5):
-    """Verifies that 'op' produces 'expected' when fed input 'inp' .
-
-    Args:
-      op: operator to test
-      inp: numpy input array to use as input to 'op'.
-      expected: numpy array representing the expected output of 'op'.
-      equality_test: either None, or a function that tests two numpy arrays for
-        equality. If None, self.assertAllClose is used.
-      rtol: relative tolerance for equality test.
-      atol: absolute tolerance for equality test.
-    """
-    with self.session() as session:
-      with self.test_scope():
-        pinp = array_ops.placeholder(
-            dtypes.as_dtype(inp.dtype), inp.shape, name='a')
-        output = op(pinp)
-      result = session.run(output, {pinp: inp})
-      if equality_test is None:
-        self.assertEqual(output.dtype, expected.dtype)
-        self.assertAllCloseAccordingToType(
-            expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
-      else:
-        equality_test(result, expected, rtol=rtol, atol=atol)
-
-  def testNumericOps(self):
-    for dtype in self.numeric_types - {np.int8, np.uint8}:
-      self._assertOpOutputMatchesExpected(
-          math_ops.abs,
-          np.array([[2, -1]], dtype=dtype),
-          expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype))
-
-
-if __name__ == '__main__':
-  googletest.main()
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index a9f5a5e743d..cd9ba983785 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -25,6 +25,7 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import bitwise_ops
 from tensorflow.python.ops import gen_nn_ops
@@ -84,6 +85,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
     for i in xrange(len(result)):
       self.assertAllClose(result[i], expected[i], rtol, atol)
 
+  @test_util.disable_mlir_bridge(
+      "MlirHloBuilder::Iota missing required for xla::Diag")
   def testAllTypeOps(self):
     for dtype in self.numeric_types - {np.int8, np.uint8}:
       self._assertOpOutputMatchesExpected(
@@ -183,6 +186,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
     self._assertOpOutputMatchesExpected(
         math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=1e-5)
 
+  @test_util.disable_mlir_bridge(
+      "TODO(b/153812660): Handle tf.Softmax compilation")
  def testFloatOps(self):
    for dtype in self.float_types:
      x = np.arange(-0.90, 0.90, 0.25)
@@ -593,6 +598,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
         np.array([-1, -0.5, 0, 0.3], dtype=dtype),
         expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
 
+  @test_util.disable_mlir_bridge(
+      "Complex types not supported in CreateDenseElementsAttrFromLiteral")
   def testComplexOps(self):
     for dtype in self.complex_types:
 
@@ -750,6 +757,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
         np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
         expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype]))
 
+  @test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
   def testIntOps(self):
     for dtype in self.int_types:
       self._assertOpOutputMatchesExpected(
@@ -823,6 +831,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
         [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32),
         expected=np.array([14., 22.], dtype=np.float32))
 
+  @test_util.disable_mlir_bridge("TODO(b/153812660): Handle tf.Cast compilation"
+                                )
   def testCast(self):
     shapes = [[], [4], [2, 3], [2, 0, 4]]
     types = {
@@ -870,6 +880,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
           src,
           expected=dst)
 
+  @test_util.disable_mlir_bridge(
+      "TODO(b/153812660): Handle tf.Bitcast compilation")
   def testBitcast(self):
     self._assertOpOutputMatchesExpected(
         lambda x: array_ops.bitcast(x, dtypes.int32),
@@ -893,12 +905,16 @@ class UnaryOpsTest(xla_test.XLATestCase):
         np.array([1, 0x100000003f800000], np.int64),
         expected=np.array([1, 0x100000003f800000], np.uint64))
 
+  @test_util.disable_mlir_bridge(
+      "TODO(b/153812660): Handle tf.InvertPermutation compilation")
   def testInvertPermutation(self):
     self._assertOpOutputMatchesExpected(
         array_ops.invert_permutation,
         np.array([1, 2, 0], np.int32),
         expected=np.array([2, 0, 1], dtype=np.int32))
 
+  @test_util.disable_mlir_bridge(
+      "TODO(b/153812660): Handle tf.InvertPermutation compilation")
   def testInvertPermutationTwiceIsNoop(self):
     self._assertOpOutputMatchesExpected(
         lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)),
@@ -990,6 +1006,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
         ],
         equality_test=self.ListsAreClose)
 
+  @test_util.disable_mlir_bridge(
+      "TODO(b/153812660): Handle tf.DepthToSpace compilation")
   def testDepthToSpace(self):
 
     def make_op(data_format):
@@ -1042,6 +1060,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
                    [[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]],
                   dtype=dtype))
 
+  @test_util.disable_mlir_bridge(
+      "TODO(b/153812660): Handle tf.SpaceToDepth compilation")
   def testSpaceToDepth(self):
 
     def make_op(data_format):
@@ -1101,6 +1121,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
     self._assertOpOutputMatchesExpected(
         nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
 
+  @test_util.disable_mlir_bridge(
+      "bf16 type not supported in CreateDenseElementsAttrFromLiteral")
   def testSoftplus(self):
     for dtype in self.float_types:
       self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 9b423bf10c5..f3d5900680c 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1738,18 +1738,15 @@ def enable_tf_xla_constant_folding(description):
   return enable_tf_xla_constant_folding_impl
 
 
-# The description is just for documentation purposes.
-def disable_xla(description):
+# Updates test function by selectively disabling it.
+def _disable_test(execute_func):
 
-  def disable_xla_impl(func):
-    """Execute the test method only if xla is not enabled."""
+  def disable_test_impl(func):
 
     def decorator(func):
 
       def decorated(self, *args, **kwargs):
-        if is_xla_enabled():
-          return
-        else:
+        if execute_func:
           return func(self, *args, **kwargs)
 
       return decorated
@@ -1759,7 +1756,21 @@ def disable_xla(description):
 
     return decorator
 
-  return disable_xla_impl
+  return disable_test_impl
+
+
+# The description is just for documentation purposes.
+def disable_xla(description):  # pylint: disable=unused-argument
+  """Execute the test method only if xla is not enabled."""
+  execute_func = not is_xla_enabled()
+  return _disable_test(execute_func)
+
+
+# The description is just for documentation purposes.
+def disable_mlir_bridge(description):  # pylint: disable=unused-argument
+  """Execute the test method only if MLIR bridge is not enabled."""
+  execute_func = not is_mlir_bridge_enabled()
+  return _disable_test(execute_func)
 
 
 def for_all_test_methods(decorator, *args, **kwargs):
@@ -1791,27 +1802,9 @@ def for_all_test_methods(decorator, *args, **kwargs):
 
 # The description is just for documentation purposes.
 def no_xla_auto_jit(description):  # pylint: disable=unused-argument
-
-  def no_xla_auto_jit_impl(func):
-    """This test is not intended to be run with XLA auto jit enabled."""
-
-    def decorator(func):
-
-      def decorated(self, *args, **kwargs):
-        if is_xla_enabled():
-          # Skip test if using XLA is forced.
-          return
-        else:
-          return func(self, *args, **kwargs)
-
-      return decorated
-
-    if func is not None:
-      return decorator(func)
-
-    return decorator
-
-  return no_xla_auto_jit_impl
+  """This test is not intended to be run with XLA auto jit enabled."""
+  execute_func = not is_xla_enabled()
+  return _disable_test(execute_func)
 
 
 # The description is just for documentation purposes.
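
Illustrative sketch (not part of the patch): the refactored _disable_test(execute_func) helper in test_util.py returns a decorator that runs the wrapped test method only when execute_func is true, so disable_mlir_bridge(description) evaluates is_mlir_bridge_enabled() once at decoration time and silently skips the body when the bridge is on. The standalone Python sketch below shows that pattern under stated assumptions: the bridge_enabled flag and ExampleTest class are hypothetical stand-ins for is_mlir_bridge_enabled() and a real XLA test case, and the `func is not None` handling present in the TensorFlow version is omitted for brevity.

import unittest

# Hypothetical stand-in for test_util.is_mlir_bridge_enabled().
bridge_enabled = False


def _disable_test(execute_func):
  """Simplified version of the pattern above: run the test only if execute_func is true."""

  def disable_test_impl(func):

    def decorated(self, *args, **kwargs):
      # Run the original test body only when execute_func is true.
      if execute_func:
        return func(self, *args, **kwargs)
      # Otherwise the body is skipped and the test method returns None.

    return decorated

  return disable_test_impl


def disable_mlir_bridge(description):  # pylint: disable=unused-argument
  """Run the test method only if the (hypothetical) MLIR bridge flag is off."""
  return _disable_test(not bridge_enabled)


class ExampleTest(unittest.TestCase):

  @disable_mlir_bridge("TODO: op not yet supported by the MLIR bridge")
  def test_abs(self):
    self.assertEqual(abs(-2), 2)


if __name__ == '__main__':
  unittest.main()

As in the TensorFlow decorator, a disabled test still reports as passing rather than skipped, and the description argument exists purely for documentation, which is why both disable_xla and disable_mlir_bridge carry the pylint unused-argument suppression.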