STT-tensorflow/tensorflow/python/compiler/xla/experimental_compile_test.py
George Karpenkov f3c0a3e0e4 Fix experimental_compile=True for graph mode
Previously the attribute only worked in eager mode, and was a no-op otherwise.
Note that this also resolves the problem of a function with
experimental_compile=True not being compiled when called from
experimental_compile=False context.
PiperOrigin-RevId: 286682281
Change-Id: Ifbc6efa2c82ae13f5d124ec6aaf440e1639a42c3
2019-12-20 21:24:40 -08:00

114 lines
4.2 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import def_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ExperimentalCompileTest(test.TestCase):
  """Tests `def_function.function(..., experimental_compile=True)` in graph mode.

  Each test builds a fresh `ops.Graph`, calls the compiled function on a
  placeholder inside it, and evaluates the result with a `session.Session`.
  XLA support is not yet enabled for TF ROCm, so the XLA-dependent parts are
  skipped on ROCm builds.
  """

  def testBasic(self):
    """Adds a Python scalar to a float32 placeholder via a must-compile fn."""
    with ops.Graph().as_default() as g:

      def fn(x, a):
        return x + a

      xla_func = def_function.function(fn, experimental_compile=True)
      inputs = array_ops.placeholder(dtypes.float32, [5])
      # XLA support is not yet enabled for TF ROCm
      if not test.is_built_with_rocm():
        x = xla_func(inputs, 1)
        with session.Session(graph=g) as sess:
          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
          # The traced function stored in the graph's function library must
          # carry the boolean `_XlaMustCompile` attribute set to True.
          self.assertTrue(x.graph.as_graph_def().library.function[0]
                          .attr["_XlaMustCompile"].b)
          self.assertAllClose([2, 3, 3, 4, 4], y)

  def testDerivative(self):
    """Checks gradients of a must-compile fn and attribute propagation.

    The gradient of `2 * x + a` w.r.t. `x` is 2 for every element. The forward
    and backward functions produced by the delayed rewrite must also carry the
    `_XlaMustCompile` attribute.
    """
    # XLA support is not yet enabled for TF ROCm
    if test.is_built_with_rocm():
      return

    def fn(x, a):
      return 2 * x + a

    with ops.Graph().as_default() as g:
      xla_func = def_function.function(fn, experimental_compile=True)
      with backprop.GradientTape() as tape:
        inputs = array_ops.placeholder(dtypes.float32, [5])
        tape.watch(inputs)
        outputs = xla_func(inputs, 1)
      grads = tape.gradient(outputs, inputs)

    with session.Session(graph=g) as sess:
      grads_tensor = sess.run(grads, feed_dict={inputs: [1, 2, 2, 3, 3]})
      self.assertAllClose([2, 2, 2, 2, 2], grads_tensor)
      (forward, backward) = xla_func.get_concrete_function(
          inputs, 1)._delayed_rewrite_functions.forward_backward()
      # Check that the must-compile attribute gets correctly propagated to the
      # created derivatives. Read the `.b` payload (as testBasic does):
      # asserting on the AttrValue message itself would not verify that the
      # flag is actually True.
      self.assertTrue(forward.definition.attr["_XlaMustCompile"].b)
      self.assertTrue(backward.function_def.attr["_XlaMustCompile"].b)

  def testBasicInt32(self):
    """Same as testBasic, but with an int32 placeholder."""
    with ops.Graph().as_default() as g:

      def fn(x, a):
        return x + a

      xla_func = def_function.function(fn, experimental_compile=True)
      inputs = array_ops.placeholder(dtypes.int32, [5])
      # XLA support is not yet enabled for TF ROCm
      if not test.is_built_with_rocm():
        x = xla_func(inputs, 1)
        with session.Session(graph=g) as sess:
          y = sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
          self.assertTrue(x.graph.as_graph_def().library.function[0]
                          .attr["_XlaMustCompile"].b)
          self.assertAllClose([2, 3, 3, 4, 4], y)

  # Checking that we crash on an unsupported operation lets us test that the XLA
  # compiler was actually invoked.
  def testUnsupportedOps(self):
    """Expects InvalidArgumentError ("not compilable") for an XLA-unsupported op."""
    with ops.Graph().as_default() as g:

      def fn(x):
        return array_ops.unique(x).y  # Unique is not supported by XLA

      xla_func = def_function.function(fn, experimental_compile=True)
      inputs = array_ops.placeholder(dtypes.float32, [5])
      x = xla_func(inputs)
      # XLA support is not yet enabled for TF ROCm
      if not test.is_built_with_rocm():
        with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                     "not compilable"):
          with session.Session(graph=g) as sess:
            sess.run(x, feed_dict={inputs: [1, 2, 2, 3, 3]})
if __name__ == "__main__":
  # Executed as a script: hand control to the TensorFlow test runner.
  test.main()