STT-tensorflow/tensorflow/python/compiler/xla/xla_test.py
Xiao Yu b8cd771a05 Disable a few failing tests in tensorflow.v2_tfrt and fix tensorflow/python/ops/ragged:ragged_map_fn_op_test.
PiperOrigin-RevId: 351449157
Change-Id: Ie6486b9c6e13e2f10143c403350aaca84d1b3441
2021-01-12 14:56:59 -08:00

344 lines
13 KiB
Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for python.compiler.xla.xla."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python import summary
from tensorflow.python.compiler.xla import xla
from tensorflow.python.eager import def_function
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.tpu import tpu_feed
_TRAIN = model_fn_lib.ModeKeys.TRAIN
_EVAL = model_fn_lib.ModeKeys.EVAL
_EXPECTED_LOSS = 1
_EXPECTED_FEATURE = 2
_EXPECTED_LABEL = 3
class XLACompileContextTest(test.TestCase, parameterized.TestCase):
def create_test_xla_compile_context(self):
computation_name = ops.get_default_graph().unique_name('computation')
pivot = control_flow_ops.no_op(name=computation_name + '/pivot')
return xla.XLACompileContext(name=computation_name, pivot=pivot)
@test_util.run_v1_only('Testing graph mode behavior only')
def test_report_unsupported_operations_graph_mode(self):
"""Tests that unsupported operations are detected."""
context = self.create_test_xla_compile_context()
context.Enter()
dummy_tensor = constant_op.constant(1.1)
audio_summary = summary.audio('audio_summary', dummy_tensor, 0.5)
histogram_summary = summary.histogram('histogram_summary', dummy_tensor)
image_summary = summary.image('image_summary', dummy_tensor)
scalar_summary = summary.scalar('scalar_summary', dummy_tensor)
tensor_summary = summary.tensor_summary('tensor_summary', dummy_tensor)
summary.merge(
[
audio_summary, histogram_summary, image_summary, scalar_summary,
tensor_summary
],
name='merge_summary')
logging_ops.Print(dummy_tensor, [dummy_tensor], name='print_op')
context.Exit()
unsupported_ops_names = [op.name for op in context._unsupported_ops]
self.assertEqual(unsupported_ops_names, [
u'audio_summary', u'histogram_summary', u'image_summary',
u'scalar_summary', u'tensor_summary', u'merge_summary/merge_summary',
u'print_op'
])
@test_util.run_v1_only('Testing graph mode behavior only')
def test_resource_variable_graph_mode(self):
"""Tests that resource variable usage is allowed."""
a = variable_scope.get_variable(
name='variable_a', use_resource=True, initializer=1)
context = self.create_test_xla_compile_context()
context.Enter()
a.assign(2)
context.Exit()
def test_resource_variable_in_function(self):
"""Tests that resource variable usage is allowed."""
a = variable_scope.get_variable(
name='variable_a', use_resource=True, initializer=1)
@def_function.function
def func():
context = self.create_test_xla_compile_context()
context.Enter()
o = a.assign(2)
context.Exit()
return o
self.assertEqual(self.evaluate(func()), 2)
@test_util.run_v1_only('Testing v1-only ref variable handling.')
def test_non_resource_variable_error(self):
"""Tests that non-resource variable usage is disallowed."""
a = variable_scope.get_variable(
name='variable_a', shape=(1), use_resource=False)
context = self.create_test_xla_compile_context()
context.Enter()
with self.assertRaisesRegex(
NotImplementedError, 'Non-resource Variables are not supported inside '
r'XLA computations \(operator name: Assign\)'):
state_ops.assign(a, a + 1)
context.Exit()
@test_util.build_as_function_and_v1_graph
def test_nested_xla_compile_error(self):
"""Tests that nested XLA computation leads to fatal error."""
context1 = self.create_test_xla_compile_context()
context1.Enter()
context2 = self.create_test_xla_compile_context()
context2.Enter()
with self.assertRaisesRegex(ValueError,
'XLA compiled computations cannot be nested'):
constant_op.constant(1)
context2.Exit()
context1.Exit()
@test_util.build_as_function_and_v1_graph
def test_xla_compile_attr(self):
"""Tests that ops are tagged with XLA compile ID attribute."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertIn('_xla_compile_id', op.op.node_def.attr)
@test_util.build_as_function_and_v1_graph
def test_op_without_input(self):
"""Tests that ops without inputs depend on pivot correctly."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertIn(context._pivot, op.op.control_inputs)
@test_util.run_v1_only('Testing graph mode behavior only')
def test_external_control_edges_graph_mode(self):
"""Tests that external control edges are handled correctly."""
i = constant_op.constant(1)
op1 = constant_op.constant(1)
with ops.control_dependencies([op1]):
op2 = constant_op.constant(1)
self.assertIn(op1.op, op2.op.control_inputs)
def while_body(i):
del i # unused
context = self.create_test_xla_compile_context()
context.Enter()
with ops.control_dependencies([op1]):
op3 = constant_op.constant(1)
context.Exit()
self.assertNotIn(op1.op, op3.op.control_inputs)
return op3
control_flow_ops.while_loop(
cond=lambda i: math_ops.less(i, 10), body=while_body, loop_vars=[i])
@test_util.build_as_function_and_v1_graph
def test_op_output_marked_as_seen(self):
"""Tests that any op output is marked as seen in context."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertIn(op.name, context._values)
@test_util.build_as_function_and_v1_graph
def test_op_is_in_context(self):
"""Tests that XLACompileContext is recognized as an XLA context."""
op1 = constant_op.constant(1)
context = self.create_test_xla_compile_context()
context.Enter()
op2 = constant_op.constant(2)
context.Exit()
self.assertFalse(control_flow_util.IsInXLAContext(op1.op))
self.assertTrue(control_flow_util.IsInXLAContext(op2.op))
@test_util.build_as_function_and_v1_graph
def test_op_prevent_feeding(self):
"""Tests that ops created inside XLACompileContext can not be fed."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertFalse(op.graph.is_feedable(op.op))
@test_util.build_as_function_and_v1_graph
def test_op_prevent_fetching(self):
"""Tests that ops created inside XLACompileContext can not be fetched."""
context = self.create_test_xla_compile_context()
context.Enter()
op = constant_op.constant(1)
context.Exit()
self.assertFalse(op.graph.is_fetchable(op.op))
class XlaCompileTest(test.TestCase):
@test_util.run_v2_only
@test_util.disable_tfrt(
'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.')
def test_xla_compile_eager(self):
"""Tests that xla.compile raises proper exception when used eagerly."""
def computation(a, b):
return a + b
self.assertEqual(self.evaluate(xla.compile(computation, [1, 2])[0]), 3)
@test_util.disable_tfrt(
'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.')
def test_xla_compile_in_function(self):
"""Tests that xla.compile works in tf.function."""
@def_function.function
def func_wrapper(a):
def compute(a):
return a + 1
return xla.compile(compute, [a])
self.assertEqual(self.evaluate(func_wrapper(1))[0], 2)
@test_util.disable_tfrt(
'Legacy XLA test. It depends on EncapsulateXlaComputationsPass.')
def test_xla_compile_write_variable_in_function(self):
"""Tests that xla.compile works with variable in tf.function."""
a = variable_scope.get_variable(
name='variable_a', use_resource=True, initializer=1)
@def_function.function
def func_wrapper():
def compute():
a.assign_add(1)
a.assign_sub(2)
return a.read_value()
return xla.compile(compute)
self.evaluate(a.initializer)
self.assertEqual(self.evaluate(func_wrapper())[0], 0)
class CheckFunctionArgumentCountTest(test.TestCase):
def test_simple(self):
"""Tests that arg checker works for functions with no varargs or defaults.
"""
def func(x, y, z):
return x + y + z
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual('exactly 3 arguments',
xla.check_function_argument_count(func, 2, None))
queue = tpu_feed.InfeedQueue(2)
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual('exactly 3 arguments',
xla.check_function_argument_count(func, 2, queue))
def test_default_args(self):
"""Tests that arg checker works for a function with no varargs."""
def func(x, y, z=17):
return x + y + z
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 1, None))
self.assertEqual('at most 3 arguments',
xla.check_function_argument_count(func, 4, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
self.assertEqual('at most 3 arguments',
xla.check_function_argument_count(func, 4, queue))
def test_var_args(self):
"""Tests that arg checker works for a function with varargs."""
def func(x, y, *z):
return x + y + len(z)
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 1, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
def test_var_args_and_defaults(self):
"""Tests that arg checker works for a function with varargs and defaults."""
def func(x, y, z=17, *q): # pylint: disable=keyword-arg-before-vararg
return x + y + z + len(q)
self.assertEqual(None, xla.check_function_argument_count(func, 2, None))
self.assertEqual(None, xla.check_function_argument_count(func, 3, None))
self.assertEqual(None, xla.check_function_argument_count(func, 4, None))
self.assertEqual(None, xla.check_function_argument_count(func, 5, None))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 1, None))
queue = tpu_feed.InfeedQueue(1)
self.assertEqual(None, xla.check_function_argument_count(func, 1, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 2, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 3, queue))
self.assertEqual(None, xla.check_function_argument_count(func, 4, queue))
self.assertEqual('at least 2 arguments',
xla.check_function_argument_count(func, 0, queue))
if __name__ == '__main__':
test.main()