Enable unary TensorFlow ops in "xla-legalize-tf-with-tf2xla" pass
This enables unary ops in unary_ops_test.py that have kernel defined in tf2xla/kernels/unary_ops.cc and doesn't already have legalizations. Some tests are disabled if the op is not supported or either using unsigned int or complex constants. This also deletes unary_mlir_ops_test test now the old and new bridge tests are consolidated. PiperOrigin-RevId: 306570457 Change-Id: Idbaab1e8986c3659916ae0da28f60bf1960b9f4e
This commit is contained in:
parent
983cbdbe61
commit
016718b562
|
@ -77,9 +77,13 @@ static bool IsOpWhitelisted(Operation* op) {
|
||||||
// building valid MLIR using MlirHloBuilder.
|
// building valid MLIR using MlirHloBuilder.
|
||||||
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
|
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
|
||||||
// all tf2xla kernels.
|
// all tf2xla kernels.
|
||||||
return isa<TF::AbsOp>(op) || isa<TF::Atan2Op>(op) || isa<TF::CastOp>(op) ||
|
return isa<TF::AbsOp>(op) || isa<TF::Atan2Op>(op) ||
|
||||||
isa<TF::GreaterOp>(op) || isa<TF::InvOp>(op) ||
|
isa<TF::BiasAddGradOp>(op) || isa<TF::CastOp>(op) ||
|
||||||
isa<TF::SelectV2Op>(op);
|
isa<TF::ComplexAbsOp>(op) || isa<TF::GreaterOp>(op) ||
|
||||||
|
isa<TF::InvOp>(op) || isa<TF::InvertOp>(op) || isa<TF::LogOp>(op) ||
|
||||||
|
isa<TF::LogicalNotOp>(op) || isa<TF::NegOp>(op) ||
|
||||||
|
isa<TF::SelectV2Op>(op) || isa<TF::SinOp>(op) ||
|
||||||
|
isa<TF::SquareOp>(op) || isa<TF::UnpackOp>(op);
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
|
static std::unique_ptr<tensorflow::StaticDeviceMgr> CreateDeviceMgr(
|
||||||
|
|
|
@ -1346,26 +1346,6 @@ tf_xla_py_test(
|
||||||
name = "unary_ops_test",
|
name = "unary_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["unary_ops_test.py"],
|
srcs = ["unary_ops_test.py"],
|
||||||
python_version = "PY3",
|
|
||||||
tags = [
|
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":xla_test",
|
|
||||||
"//tensorflow/python:array_ops",
|
|
||||||
"//tensorflow/python:framework",
|
|
||||||
"//tensorflow/python:math_ops",
|
|
||||||
"//tensorflow/python:nn_ops",
|
|
||||||
"//tensorflow/python:nn_ops_gen",
|
|
||||||
"//tensorflow/python:platform_test",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(hinsu): Combine this test with unary_ops_test instead of replicating it.
|
|
||||||
tf_xla_py_test(
|
|
||||||
name = "unary_mlir_ops_test",
|
|
||||||
size = "medium",
|
|
||||||
srcs = ["unary_mlir_ops_test.py"],
|
|
||||||
enable_mlir_bridge = True,
|
enable_mlir_bridge = True,
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
|
|
|
@ -1,73 +0,0 @@
|
||||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ==============================================================================
|
|
||||||
"""Tests for XLA JIT compiler."""
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from tensorflow.compiler.tests import xla_test
|
|
||||||
from tensorflow.python.framework import dtypes
|
|
||||||
from tensorflow.python.ops import array_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
|
||||||
from tensorflow.python.platform import googletest
|
|
||||||
|
|
||||||
|
|
||||||
class UnaryOpsTest(xla_test.XLATestCase):
|
|
||||||
"""Test cases for unary operators."""
|
|
||||||
|
|
||||||
def _assertOpOutputMatchesExpected(self,
|
|
||||||
op,
|
|
||||||
inp,
|
|
||||||
expected,
|
|
||||||
equality_test=None,
|
|
||||||
rtol=1e-3,
|
|
||||||
atol=1e-5):
|
|
||||||
"""Verifies that 'op' produces 'expected' when fed input 'inp' .
|
|
||||||
|
|
||||||
Args:
|
|
||||||
op: operator to test
|
|
||||||
inp: numpy input array to use as input to 'op'.
|
|
||||||
expected: numpy array representing the expected output of 'op'.
|
|
||||||
equality_test: either None, or a function that tests two numpy arrays for
|
|
||||||
equality. If None, self.assertAllClose is used.
|
|
||||||
rtol: relative tolerance for equality test.
|
|
||||||
atol: absolute tolerance for equality test.
|
|
||||||
"""
|
|
||||||
with self.session() as session:
|
|
||||||
with self.test_scope():
|
|
||||||
pinp = array_ops.placeholder(
|
|
||||||
dtypes.as_dtype(inp.dtype), inp.shape, name='a')
|
|
||||||
output = op(pinp)
|
|
||||||
result = session.run(output, {pinp: inp})
|
|
||||||
if equality_test is None:
|
|
||||||
self.assertEqual(output.dtype, expected.dtype)
|
|
||||||
self.assertAllCloseAccordingToType(
|
|
||||||
expected, result, rtol=rtol, atol=atol, bfloat16_rtol=0.03)
|
|
||||||
else:
|
|
||||||
equality_test(result, expected, rtol=rtol, atol=atol)
|
|
||||||
|
|
||||||
def testNumericOps(self):
|
|
||||||
for dtype in self.numeric_types - {np.int8, np.uint8}:
|
|
||||||
self._assertOpOutputMatchesExpected(
|
|
||||||
math_ops.abs,
|
|
||||||
np.array([[2, -1]], dtype=dtype),
|
|
||||||
expected=np.array([[2, 1]], dtype=np.real(dtype(0)).dtype))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
googletest.main()
|
|
|
@ -25,6 +25,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
|
|
||||||
from tensorflow.compiler.tests import xla_test
|
from tensorflow.compiler.tests import xla_test
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import bitwise_ops
|
from tensorflow.python.ops import bitwise_ops
|
||||||
from tensorflow.python.ops import gen_nn_ops
|
from tensorflow.python.ops import gen_nn_ops
|
||||||
|
@ -84,6 +85,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
for i in xrange(len(result)):
|
for i in xrange(len(result)):
|
||||||
self.assertAllClose(result[i], expected[i], rtol, atol)
|
self.assertAllClose(result[i], expected[i], rtol, atol)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"MlirHloBuilder::Iota missing required for xla::Diag")
|
||||||
def testAllTypeOps(self):
|
def testAllTypeOps(self):
|
||||||
for dtype in self.numeric_types - {np.int8, np.uint8}:
|
for dtype in self.numeric_types - {np.int8, np.uint8}:
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
|
@ -183,6 +186,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=1e-5)
|
math_ops.cos, x, expected=np.cos(x), rtol=tol, atol=1e-5)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/153812660): Handle tf.Softmax compilation")
|
||||||
def testFloatOps(self):
|
def testFloatOps(self):
|
||||||
for dtype in self.float_types:
|
for dtype in self.float_types:
|
||||||
x = np.arange(-0.90, 0.90, 0.25)
|
x = np.arange(-0.90, 0.90, 0.25)
|
||||||
|
@ -593,6 +598,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
|
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
|
||||||
expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
|
expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"Complex types not supported in CreateDenseElementsAttrFromLiteral")
|
||||||
def testComplexOps(self):
|
def testComplexOps(self):
|
||||||
for dtype in self.complex_types:
|
for dtype in self.complex_types:
|
||||||
|
|
||||||
|
@ -750,6 +757,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype),
|
||||||
expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype]))
|
expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype]))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("TODO(b/153896312): Handle unsigned ints")
|
||||||
def testIntOps(self):
|
def testIntOps(self):
|
||||||
for dtype in self.int_types:
|
for dtype in self.int_types:
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
|
@ -823,6 +831,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32),
|
[[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32),
|
||||||
expected=np.array([14., 22.], dtype=np.float32))
|
expected=np.array([14., 22.], dtype=np.float32))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("TODO(b/153812660): Handle tf.Cast compilation"
|
||||||
|
)
|
||||||
def testCast(self):
|
def testCast(self):
|
||||||
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
shapes = [[], [4], [2, 3], [2, 0, 4]]
|
||||||
types = {
|
types = {
|
||||||
|
@ -870,6 +880,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
src,
|
src,
|
||||||
expected=dst)
|
expected=dst)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/153812660): Handle tf.Bitcast compilation")
|
||||||
def testBitcast(self):
|
def testBitcast(self):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
lambda x: array_ops.bitcast(x, dtypes.int32),
|
lambda x: array_ops.bitcast(x, dtypes.int32),
|
||||||
|
@ -893,12 +905,16 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
np.array([1, 0x100000003f800000], np.int64),
|
np.array([1, 0x100000003f800000], np.int64),
|
||||||
expected=np.array([1, 0x100000003f800000], np.uint64))
|
expected=np.array([1, 0x100000003f800000], np.uint64))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/153812660): Handle tf.InvertPermutation compilation")
|
||||||
def testInvertPermutation(self):
|
def testInvertPermutation(self):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
array_ops.invert_permutation,
|
array_ops.invert_permutation,
|
||||||
np.array([1, 2, 0], np.int32),
|
np.array([1, 2, 0], np.int32),
|
||||||
expected=np.array([2, 0, 1], dtype=np.int32))
|
expected=np.array([2, 0, 1], dtype=np.int32))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/153812660): Handle tf.InvertPermutation compilation")
|
||||||
def testInvertPermutationTwiceIsNoop(self):
|
def testInvertPermutationTwiceIsNoop(self):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)),
|
lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)),
|
||||||
|
@ -990,6 +1006,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
],
|
],
|
||||||
equality_test=self.ListsAreClose)
|
equality_test=self.ListsAreClose)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/153812660): Handle tf.DepthToSpace compilation")
|
||||||
def testDepthToSpace(self):
|
def testDepthToSpace(self):
|
||||||
|
|
||||||
def make_op(data_format):
|
def make_op(data_format):
|
||||||
|
@ -1042,6 +1060,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
[[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]],
|
[[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]],
|
||||||
dtype=dtype))
|
dtype=dtype))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"TODO(b/153812660): Handle tf.SpaceToDepth compilation")
|
||||||
def testSpaceToDepth(self):
|
def testSpaceToDepth(self):
|
||||||
|
|
||||||
def make_op(data_format):
|
def make_op(data_format):
|
||||||
|
@ -1101,6 +1121,8 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||||
self._assertOpOutputMatchesExpected(
|
self._assertOpOutputMatchesExpected(
|
||||||
nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
|
nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge(
|
||||||
|
"bf16 type not supported in CreateDenseElementsAttrFromLiteral")
|
||||||
def testSoftplus(self):
|
def testSoftplus(self):
|
||||||
for dtype in self.float_types:
|
for dtype in self.float_types:
|
||||||
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
|
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
|
||||||
|
|
|
@ -1738,18 +1738,15 @@ def enable_tf_xla_constant_folding(description):
|
||||||
return enable_tf_xla_constant_folding_impl
|
return enable_tf_xla_constant_folding_impl
|
||||||
|
|
||||||
|
|
||||||
# The description is just for documentation purposes.
|
# Updates test function by selectively disabling it.
|
||||||
def disable_xla(description):
|
def _disable_test(execute_func):
|
||||||
|
|
||||||
def disable_xla_impl(func):
|
def disable_test_impl(func):
|
||||||
"""Execute the test method only if xla is not enabled."""
|
|
||||||
|
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
|
|
||||||
def decorated(self, *args, **kwargs):
|
def decorated(self, *args, **kwargs):
|
||||||
if is_xla_enabled():
|
if execute_func:
|
||||||
return
|
|
||||||
else:
|
|
||||||
return func(self, *args, **kwargs)
|
return func(self, *args, **kwargs)
|
||||||
|
|
||||||
return decorated
|
return decorated
|
||||||
|
@ -1759,7 +1756,21 @@ def disable_xla(description):
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
return disable_xla_impl
|
return disable_test_impl
|
||||||
|
|
||||||
|
|
||||||
|
# The description is just for documentation purposes.
|
||||||
|
def disable_xla(description): # pylint: disable=unused-argument
|
||||||
|
"""Execute the test method only if xla is not enabled."""
|
||||||
|
execute_func = not is_xla_enabled()
|
||||||
|
return _disable_test(execute_func)
|
||||||
|
|
||||||
|
|
||||||
|
# The description is just for documentation purposes.
|
||||||
|
def disable_mlir_bridge(description): # pylint: disable=unused-argument
|
||||||
|
"""Execute the test method only if MLIR bridge is not enabled."""
|
||||||
|
execute_func = not is_mlir_bridge_enabled()
|
||||||
|
return _disable_test(execute_func)
|
||||||
|
|
||||||
|
|
||||||
def for_all_test_methods(decorator, *args, **kwargs):
|
def for_all_test_methods(decorator, *args, **kwargs):
|
||||||
|
@ -1791,27 +1802,9 @@ def for_all_test_methods(decorator, *args, **kwargs):
|
||||||
|
|
||||||
# The description is just for documentation purposes.
|
# The description is just for documentation purposes.
|
||||||
def no_xla_auto_jit(description): # pylint: disable=unused-argument
|
def no_xla_auto_jit(description): # pylint: disable=unused-argument
|
||||||
|
|
||||||
def no_xla_auto_jit_impl(func):
|
|
||||||
"""This test is not intended to be run with XLA auto jit enabled."""
|
"""This test is not intended to be run with XLA auto jit enabled."""
|
||||||
|
execute_func = not is_xla_enabled()
|
||||||
def decorator(func):
|
return _disable_test(execute_func)
|
||||||
|
|
||||||
def decorated(self, *args, **kwargs):
|
|
||||||
if is_xla_enabled():
|
|
||||||
# Skip test if using XLA is forced.
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
return func(self, *args, **kwargs)
|
|
||||||
|
|
||||||
return decorated
|
|
||||||
|
|
||||||
if func is not None:
|
|
||||||
return decorator(func)
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
return no_xla_auto_jit_impl
|
|
||||||
|
|
||||||
|
|
||||||
# The description is just for documentation purposes.
|
# The description is just for documentation purposes.
|
||||||
|
|
Loading…
Reference in New Issue