From 13a1a9a71c084cda8f676f40a711df26e6f3c637 Mon Sep 17 00:00:00 2001
From: Eugene Brevdo <ebrevdo@google.com>
Date: Thu, 26 Jan 2017 12:06:38 -0800
Subject: [PATCH] Make LSTMCell use Defuns to speed up static graph builds, add
 compiled flag. Change: 145703555
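
This adds a `compiled` constructor argument to LSTMCell (default False).
When it is True, the per-step cell computation is wrapped in a
function.Defun(noinline=True, compiled=True) so XLA can compile the cell
body; defining the body once and calling it per step also speeds up
static graph builds.  compiled=True requires state_is_tuple=True.  The
test macros in tensorflow.bzl gain an xla_enabled flag that appends
tf_additional_xla_deps_py() to the test deps, and WITH_XLA_SUPPORT moves
from build_config.bzl to build_config_root.bzl so it can be referenced
without cyclic loads.

A minimal usage sketch, mirroring the _CreateMultiLSTMCellOps helper in
the new tests (the `inputs` tensor and the sizes below are illustrative
assumptions, not part of this change):

    cell = core_rnn_cell_impl.MultiRNNCell(
        [core_rnn_cell_impl.LSTMCell(32, compiled=True)
         for _ in range(2)])
    # state_is_tuple defaults to True, which compiled=True requires.
    outputs, final_state = rnn.dynamic_rnn(
        cell=cell, inputs=inputs, dtype=dtypes.float32, time_major=True)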

---
 configure                                     |   4 +-
 tensorflow/contrib/rnn/BUILD                  |   2 +
 .../python/kernel_tests/core_rnn_cell_test.py | 165 ++++++++++++++-
 .../rnn/python/kernel_tests/core_rnn_test.py  | 103 ++++-----
 .../rnn/python/ops/core_rnn_cell_impl.py      | 195 ++++++++++++------
 .../seq2seq/python/ops/sampling_decoder.py    |  14 +-
 .../core/platform/default/build_config.bzl    |  13 --
 .../platform/default/build_config_root.bzl    |  17 ++
 tensorflow/python/BUILD                       |   2 +-
 tensorflow/tensorflow.bzl                     |  29 ++-
 tensorflow/tools/pip_package/BUILD            |   2 +-
 11 files changed, 388 insertions(+), 158 deletions(-)

diff --git a/configure b/configure
index a8e7bb77385..372ec2cee87 100755
--- a/configure
+++ b/configure
@@ -168,10 +168,10 @@ done
 
 if [ "$TF_ENABLE_XLA" == "1" ]; then
   # Update Bazel build configuration.
-  perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
+  sed -i -E -e "s/WITH_XLA_SUPPORT = (False|True)/WITH_XLA_SUPPORT = True/" tensorflow/core/platform/default/build_config_root.bzl
 else
   # Update Bazel build configuration.
-  perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
+  sed -i -E -e "s/WITH_XLA_SUPPORT = (False|True)/WITH_XLA_SUPPORT = False/" tensorflow/core/platform/default/build_config_root.bzl
 fi
 
 
diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD
index bed23625d32..c02423f7a39 100644
--- a/tensorflow/contrib/rnn/BUILD
+++ b/tensorflow/contrib/rnn/BUILD
@@ -71,6 +71,7 @@ cuda_py_tests(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
     ],
+    xla_enabled = True,
 )
 
 cuda_py_tests(
@@ -91,6 +92,7 @@ cuda_py_tests(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
     ],
+    xla_enabled = True,
 )
 
 cuda_py_tests(
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 0d9285ccb8f..8090743e6cf 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
+import itertools
 import sys
 
 # TODO: #6568 Remove this hack that makes dlopen() not crash.
@@ -33,9 +34,14 @@ import numpy as np
 
 from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -43,10 +49,41 @@ from tensorflow.python.ops import rnn
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test
+from tensorflow.python.util import nest
 
 # pylint: enable=protected-access
 
 
+def _CreateMultiLSTMCellOps(batch_size, num_units, input_depth,
+                            num_layers, max_time, compiled):
+  with variable_scope.variable_scope(
+      "root",
+      initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
+    inputs = random_ops.random_uniform(
+        (max_time, batch_size, input_depth), seed=1)
+    rnn_cell = core_rnn_cell_impl.MultiRNNCell(
+        [core_rnn_cell_impl.LSTMCell(num_units, compiled=compiled)
+         for _ in range(num_layers)])
+    initial_state = rnn_cell.zero_state(
+        batch_size=batch_size, dtype=dtypes.float32)
+    outputs, final_state = rnn.dynamic_rnn(
+        cell=rnn_cell, inputs=inputs, initial_state=initial_state,
+        time_major=True)
+    flat_final_state = nest.flatten(final_state)
+    trainable_variables = variables_lib.trainable_variables()
+    outputs_grad = gradients_impl.gradients(
+        [outputs],
+        trainable_variables + [inputs] + nest.flatten(initial_state))
+    final_state_grad = gradients_impl.gradients(
+        flat_final_state,
+        trainable_variables + [inputs] + nest.flatten(initial_state))
+
+    return {"outputs": outputs,
+            "final_state": flat_final_state,
+            "outputs_grad": outputs_grad,
+            "final_state_grad": final_state_grad}
+
+
 class RNNCellTest(test.TestCase):
 
   def testLinear(self):
@@ -117,8 +154,8 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 8])
         g, out_m = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(
-                2, state_is_tuple=False)] * 2,
+            [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+             for _ in range(2)],
             state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
@@ -165,7 +202,8 @@ class RNNCellTest(test.TestCase):
         m0 = (array_ops.zeros([1, 2]),) * 2
         m1 = (array_ops.zeros([1, 2]),) * 2
         cell = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(2)] * 2, state_is_tuple=True)
+            [core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
+            state_is_tuple=True)
         self.assertTrue(isinstance(cell.state_size, tuple))
         self.assertTrue(
             isinstance(cell.state_size[0], core_rnn_cell_impl.LSTMStateTuple))
@@ -197,8 +235,8 @@ class RNNCellTest(test.TestCase):
         m0 = array_ops.zeros([1, 4])
         m1 = array_ops.zeros([1, 4])
         cell = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(
-                2, state_is_tuple=False)] * 2,
+            [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
+             for _ in range(2)],
             state_is_tuple=True)
         g, (out_m0, out_m1) = cell(x, (m0, m1))
         sess.run([variables_lib.global_variables_initializer()])
@@ -407,7 +445,8 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 4])
         _, ml = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=False)(x, m)
+            [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
+            state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(ml, {
             x.name: np.array([[1., 1.]]),
@@ -416,6 +455,48 @@ class RNNCellTest(test.TestCase):
         # The numbers in results were not calculated, this is just a smoke test.
         self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])
 
+  def testMultiRNNCellWithLSTMCellAndXLA(self):
+    # TODO(b/34735319): Don't run this test if XLA is not available.
+    batch_size = 16
+    num_units = 32
+    input_depth = 12
+    num_layers = 2
+    max_time = 20
+
+    random_seed.set_random_seed(1234)
+    with self.test_session(graph=ops.Graph()) as sess:
+      xla_ops = _CreateMultiLSTMCellOps(
+          batch_size=batch_size, num_units=num_units,
+          input_depth=input_depth, num_layers=num_layers,
+          max_time=max_time,
+          compiled=True)
+      sess.run([variables_lib.global_variables_initializer()])
+      xla_results = sess.run(xla_ops)
+
+    random_seed.set_random_seed(1234)
+    with self.test_session(graph=ops.Graph()) as sess:
+      non_xla_ops = _CreateMultiLSTMCellOps(
+          batch_size=batch_size, num_units=num_units,
+          input_depth=input_depth, num_layers=num_layers,
+          max_time=max_time,
+          compiled=False)
+      sess.run([variables_lib.global_variables_initializer()])
+      non_xla_results = sess.run(non_xla_ops)
+
+    self.assertAllClose(non_xla_results["outputs"], xla_results["outputs"])
+
+    for xla_value, non_xla_value in zip(
+        xla_results["final_state"], non_xla_results["final_state"]):
+      self.assertAllClose(xla_value, non_xla_value)
+
+    for xla_g, non_xla_g in zip(
+        xla_results["outputs_grad"], non_xla_results["outputs_grad"]):
+      self.assertAllClose(xla_g, non_xla_g)
+
+    for xla_g, non_xla_g in zip(
+        xla_results["final_state_grad"], non_xla_results["final_state_grad"]):
+      self.assertAllClose(xla_g, non_xla_g)
+
   def testMultiRNNCellWithStateTuple(self):
     with self.test_session() as sess:
       with variable_scope.variable_scope(
@@ -427,11 +508,12 @@ class RNNCellTest(test.TestCase):
         # Test incorrectness of state
         with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
           core_rnn_cell_impl.MultiRNNCell(
-              [core_rnn_cell_impl.GRUCell(2)] * 2,
+              [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
               state_is_tuple=True)(x, m_bad)
 
         _, ml = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=True)(x, m_good)
+            [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
+            state_is_tuple=True)(x, m_good)
 
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(ml, {
@@ -490,7 +572,7 @@ class SlimRNNCellTest(test.TestCase):
         self.assertAllClose(res[1], res[3])
 
 
-def basic_rnn_cell(inputs, state, num_units, scope=None):
+def basic_rnn_cell(inputs, state, num_units, scope=None):  # pylint: disable=invalid-name
   if state is None:
     if inputs is not None:
       batch_size = inputs.get_shape()[0]
@@ -512,5 +594,70 @@ def basic_rnn_cell(inputs, state, num_units, scope=None):
     return output, output
 
 
+class BenchmarkLSTMCellXLA(test.Benchmark):
+
+  def benchmarkDynamicRNNWithMultiLSTMCell(self):
+    num_layers = 3
+    max_time = 50
+    print("benchmarkDynamicRNNWithMultiLSTMCell")
+    print("\t" +
+          "\t".join(["inter_th", "intra_th",
+                     "batch_size", "num_units", "input_depth", "device",
+                     "compiled", "wall_time"]))
+
+    warmup_run = True
+    for (threads,
+         device,
+         num_units,
+         batch_size,
+         input_depth,
+         compiled) in itertools.product(
+             [{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}],
+             ["cpu", "gpu"],
+             [32, 512],
+             [1, 32, 256],
+             [32, 512],
+             [False, True]):
+      if threads["inter"] != 0:
+        # We only care about testing inter/intra op limitations on
+        # CPU with small batch size, to mimic embedded devices.
+        if device != "cpu" or batch_size != 1:
+          continue
+      if device == "cpu" and batch_size > 32:
+        continue
+      random_seed.set_random_seed(1234)
+      config = config_pb2.ConfigProto(
+          inter_op_parallelism_threads=threads["inter"],
+          intra_op_parallelism_threads=threads["intra"],
+          allow_soft_placement=False)
+      with session.Session(config=config, graph=ops.Graph()) as sess:
+        with ops.device("/%s:0" % device):
+          ops_dict = _CreateMultiLSTMCellOps(
+              batch_size=batch_size, num_units=num_units,
+              input_depth=input_depth, num_layers=num_layers,
+              max_time=max_time,
+              compiled=compiled)
+        sess.run([variables_lib.global_variables_initializer()])
+        all_ops = nest.flatten(ops_dict.values())
+        all_ops_group = control_flow_ops.group(*all_ops)
+        name_suffix = (
+            "inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
+            "_device_%s_xla_%s" % (
+                threads["inter"], threads["intra"],
+                batch_size, num_units, input_depth, device, compiled))
+        if warmup_run:
+          self.run_op_benchmark(
+              sess, all_ops_group, min_iters=30, name="ignore_warmup")
+          warmup_run = False
+        benchmark_results = self.run_op_benchmark(
+            sess, all_ops_group, min_iters=30,
+            name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix)
+        print("\t" +
+              "\t".join(["%s" % x for x in [
+                  threads["inter"], threads["intra"],
+                  batch_size, num_units, input_depth, device, compiled,
+                  benchmark_results["wall_time"]]]))
+
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
index 3c84c34726f..67e026dabf8 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py
@@ -154,6 +154,7 @@ class RNNTest(test.TestCase):
   def setUp(self):
     self._seed = 23489
     np.random.seed(self._seed)
+    ops_lib.reset_default_graph()
 
   def testInvalidSequenceLengthShape(self):
     cell = Plus1RNNCell()
@@ -583,7 +584,7 @@ class LSTMTest(test.TestCase):
       (state_notuple_v,) = sess.run((state_notuple,),
                                     feed_dict={inputs[0]: input_value})
       state_tuple_v = sess.run(state_tuple, feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v))
+      self.assertAllClose(state_notuple_v, np.hstack(state_tuple_v))
 
   def _testProjSharding(self, use_gpu):
     num_units = 3
@@ -806,7 +807,7 @@ class LSTMTest(test.TestCase):
       self.assertEqual(len(outputs0_values), len(outputs2_values))
       for o1, o2, o3 in zip(outputs0_values, outputs1_values, outputs2_values):
         # Same weights used by both RNNs so outputs should be the same.
-        self.assertAllEqual(o1, o2)
+        self.assertAllClose(o1, o2)
         # Different weights used so outputs should be different.
         self.assertTrue(np.linalg.norm(o1 - o3) > 1e-6)
 
@@ -844,7 +845,7 @@ class LSTMTest(test.TestCase):
       outputs1_values = output_values[max_length:]
       self.assertEqual(len(outputs0_values), len(outputs1_values))
       for out0, out1 in zip(outputs0_values, outputs1_values):
-        self.assertAllEqual(out0, out1)
+        self.assertAllClose(out0, out1)
 
   def testNoProjNoShardingSimpleStateSaver(self):
     self._testNoProjNoShardingSimpleStateSaver(use_gpu=False)
@@ -934,13 +935,13 @@ class LSTMTest(test.TestCase):
                                   feed_dict={inputs[0]: input_value})
       outputs_dynamic_v = sess.run(outputs_dynamic,
                                    feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
+      self.assertAllClose(outputs_static_v, outputs_dynamic_v)
 
       state_static_v = sess.run(state_static,
                                 feed_dict={inputs[0]: input_value})
       state_dynamic_v = sess.run(state_dynamic,
                                  feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v))
 
   def testDynamicRNNWithNestedTupleStates(self):
     num_units = 3
@@ -1003,13 +1004,13 @@ class LSTMTest(test.TestCase):
                                   feed_dict={inputs[0]: input_value})
       outputs_dynamic_v = sess.run(outputs_dynamic,
                                    feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
+      self.assertAllClose(outputs_static_v, outputs_dynamic_v)
 
       state_static_v = sess.run(nest.flatten(state_static),
                                 feed_dict={inputs[0]: input_value})
       state_dynamic_v = sess.run(nest.flatten(state_dynamic),
                                  feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v))
 
   def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
     time_steps = 8
@@ -1038,7 +1039,9 @@ class LSTMTest(test.TestCase):
           use_peepholes=True,
           initializer=initializer,
           num_proj=num_proj,
-          state_is_tuple=False)
+          state_is_tuple=False,
+          # TODO(b/XXX): Defun name aliasing causes errors
+          compiled=False)
 
       with variable_scope.variable_scope("dynamic_scope"):
         outputs_static, state_static = core_rnn.static_rnn(
@@ -1096,7 +1099,9 @@ class LSTMTest(test.TestCase):
           use_peepholes=True,
           initializer=initializer,
           num_proj=num_proj,
-          state_is_tuple=False)
+          state_is_tuple=False,
+          # TODO(b/XXX): Defun name aliasing causes errors
+          compiled=False)
 
       with variable_scope.variable_scope("dynamic_scope"):
         outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
@@ -1150,10 +1155,10 @@ class LSTMTest(test.TestCase):
     ######### Step 3: Comparisons
     self.assertEqual(len(values_static), len(values_dynamic))
     for (value_static, value_dynamic) in zip(values_static, values_dynamic):
-      self.assertAllEqual(value_static, value_dynamic)
-    self.assertAllEqual(state_value_static, state_value_dynamic)
+      self.assertAllClose(value_static, value_dynamic)
+    self.assertAllClose(state_value_static, state_value_dynamic)
 
-    self.assertAllEqual(static_grad_values, dynamic_grad_values)
+    self.assertAllClose(static_grad_values, dynamic_grad_values)
 
     self.assertEqual(
         len(static_individual_grad_values), len(dynamic_individual_grad_values))
@@ -1164,14 +1169,14 @@ class LSTMTest(test.TestCase):
     for i, (a, b) in enumerate(
         zip(static_individual_grad_values, dynamic_individual_grad_values)):
       tf_logging.info("Comparing individual gradients iteration %d" % i)
-      self.assertAllEqual(a, b)
+      self.assertAllClose(a, b)
 
     for i, (a, b) in enumerate(
         zip(static_individual_var_grad_values,
             dynamic_individual_var_grad_values)):
       tf_logging.info("Comparing individual variable gradients iteration %d" %
                       i)
-      self.assertAllEqual(a, b)
+      self.assertAllClose(a, b)
 
   def testDynamicEquivalentToStaticRNN(self):
     self._testDynamicEquivalentToStaticRNN(
@@ -1293,13 +1298,13 @@ class BidirectionalRNNTest(test.TestCase):
       # Both sequences in batch are length=8.  Check that the time=i
       # forward output is equal to time=8-1-i backward output
       for i in xrange(8):
-        self.assertEqual(out[i][0][0], out[8 - 1 - i][0][3])
-        self.assertEqual(out[i][0][1], out[8 - 1 - i][0][4])
-        self.assertEqual(out[i][0][2], out[8 - 1 - i][0][5])
+        self.assertAllClose(out[i][0][0], out[8 - 1 - i][0][3])
+        self.assertAllClose(out[i][0][1], out[8 - 1 - i][0][4])
+        self.assertAllClose(out[i][0][2], out[8 - 1 - i][0][5])
       for i in xrange(8):
-        self.assertEqual(out[i][1][0], out[8 - 1 - i][1][3])
-        self.assertEqual(out[i][1][1], out[8 - 1 - i][1][4])
-        self.assertEqual(out[i][1][2], out[8 - 1 - i][1][5])
+        self.assertAllClose(out[i][1][0], out[8 - 1 - i][1][3])
+        self.assertAllClose(out[i][1][1], out[8 - 1 - i][1][4])
+        self.assertAllClose(out[i][1][2], out[8 - 1 - i][1][5])
       # Via the reasoning above, the forward and backward final state should be
       # exactly the same
       self.assertAllClose(s_fw, s_bw)
@@ -1399,27 +1404,27 @@ class BidirectionalRNNTest(test.TestCase):
       # Check that the time=0 forward output is equal to time=1 backward output
       if not use_time_major:
         out = np.swapaxes(out, 0, 1)
-      self.assertEqual(out[0][0][0], out[1][0][3])
-      self.assertEqual(out[0][0][1], out[1][0][4])
-      self.assertEqual(out[0][0][2], out[1][0][5])
+      self.assertAllClose(out[0][0][0], out[1][0][3])
+      self.assertAllClose(out[0][0][1], out[1][0][4])
+      self.assertAllClose(out[0][0][2], out[1][0][5])
       # Check that the time=1 forward output is equal to time=0 backward output
-      self.assertEqual(out[1][0][0], out[0][0][3])
-      self.assertEqual(out[1][0][1], out[0][0][4])
-      self.assertEqual(out[1][0][2], out[0][0][5])
+      self.assertAllClose(out[1][0][0], out[0][0][3])
+      self.assertAllClose(out[1][0][1], out[0][0][4])
+      self.assertAllClose(out[1][0][2], out[0][0][5])
 
       # Second sequence in batch is length=3
       # Check that the time=0 forward output is equal to time=2 backward output
-      self.assertEqual(out[0][1][0], out[2][1][3])
-      self.assertEqual(out[0][1][1], out[2][1][4])
-      self.assertEqual(out[0][1][2], out[2][1][5])
+      self.assertAllClose(out[0][1][0], out[2][1][3])
+      self.assertAllClose(out[0][1][1], out[2][1][4])
+      self.assertAllClose(out[0][1][2], out[2][1][5])
       # Check that the time=1 forward output is equal to time=1 backward output
-      self.assertEqual(out[1][1][0], out[1][1][3])
-      self.assertEqual(out[1][1][1], out[1][1][4])
-      self.assertEqual(out[1][1][2], out[1][1][5])
+      self.assertAllClose(out[1][1][0], out[1][1][3])
+      self.assertAllClose(out[1][1][1], out[1][1][4])
+      self.assertAllClose(out[1][1][2], out[1][1][5])
       # Check that the time=2 forward output is equal to time=0 backward output
-      self.assertEqual(out[2][1][0], out[0][1][3])
-      self.assertEqual(out[2][1][1], out[0][1][4])
-      self.assertEqual(out[2][1][2], out[0][1][5])
+      self.assertAllClose(out[2][1][0], out[0][1][3])
+      self.assertAllClose(out[2][1][1], out[0][1][4])
+      self.assertAllClose(out[2][1][2], out[0][1][5])
       # Via the reasoning above, the forward and backward final state should be
       # exactly the same
       self.assertAllClose(s_fw, s_bw)
@@ -1560,13 +1565,13 @@ class MultiDimensionalLSTMTest(test.TestCase):
       outputs_sav_v = sess.run(outputs_sav,
                                feed_dict={inputs_using_dim[0]: input_value})
 
-      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
-      self.assertAllEqual(outputs_static_v, outputs_sav_v)
+      self.assertAllClose(outputs_static_v, outputs_dynamic_v)
+      self.assertAllClose(outputs_static_v, outputs_sav_v)
       outputs_static_array = np.array(outputs_static_v)
       outputs_static_array_double = np.concatenate(
           (outputs_static_array, outputs_static_array), axis=2)
       outputs_bid_array = np.array(outputs_bid_v)
-      self.assertAllEqual(outputs_static_array_double, outputs_bid_array)
+      self.assertAllClose(outputs_static_array_double, outputs_bid_array)
 
       state_static_v = sess.run(state_static,
                                 feed_dict={inputs[0]: input_value})
@@ -1578,10 +1583,10 @@ class MultiDimensionalLSTMTest(test.TestCase):
                                 feed_dict={inputs_using_dim[0]: input_value})
       state_sav_v = sess.run(state_sav,
                              feed_dict={inputs_using_dim[0]: input_value})
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v))
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_sav_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
 
 
 class NestedLSTMTest(test.TestCase):
@@ -1663,14 +1668,14 @@ class NestedLSTMTest(test.TestCase):
       outputs_bid_v = sess.run(outputs_bid,
                                feed_dict={single_input_using_dim: input_value})
 
-      self.assertAllEqual(outputs_static_v,
+      self.assertAllClose(outputs_static_v,
                           np.transpose(outputs_dynamic_v, (1, 0, 2, 3)))
-      self.assertAllEqual(outputs_static_v, outputs_sav_v)
+      self.assertAllClose(outputs_static_v, outputs_sav_v)
       outputs_static_array = np.array(outputs_static_v)
       outputs_static_array_double = np.concatenate(
           (outputs_static_array, outputs_static_array), axis=3)
       outputs_bid_array = np.array(outputs_bid_v)
-      self.assertAllEqual(outputs_static_array_double, outputs_bid_array)
+      self.assertAllClose(outputs_static_array_double, outputs_bid_array)
 
       state_dynamic_v = sess.run(state_dynamic,
                                  feed_dict={single_input: input_value})
@@ -1682,10 +1687,10 @@ class NestedLSTMTest(test.TestCase):
                                 feed_dict={single_input_using_dim: input_value})
       state_sav_v = sess.run(state_sav,
                              feed_dict={single_input_using_dim: input_value})
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v))
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_sav_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
+      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
 
 
 class StateSaverRNNTest(test.TestCase):
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
index 2d65d956a8b..c2843edaf2e 100644
--- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
+++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
@@ -22,6 +22,7 @@ from __future__ import print_function
 import collections
 import math
 
+from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -61,7 +62,7 @@ class BasicRNNCell(RNNCell):
     """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
     with vs.variable_scope(scope or "basic_rnn_cell"):
       output = self._activation(
-          _linear([inputs, state], self._num_units, True, scope=scope))
+          _linear([inputs, state], self._num_units, True))
     return output, output
 
 
@@ -89,14 +90,13 @@ class GRUCell(RNNCell):
         # We start with bias of 1.0 to not reset and not update.
         r, u = array_ops.split(
             value=_linear(
-                [inputs, state], 2 * self._num_units, True, 1.0, scope=scope),
+                [inputs, state], 2 * self._num_units, True, 1.0),
             num_or_size_splits=2,
             axis=1)
         r, u = sigmoid(r), sigmoid(u)
       with vs.variable_scope("candidate"):
         c = self._activation(_linear([inputs, r * state],
-                                     self._num_units, True,
-                                     scope=scope))
+                                     self._num_units, True))
       new_h = u * state + (1 - u) * c
     return new_h, new_h
 
@@ -176,7 +176,7 @@ class BasicLSTMCell(RNNCell):
         c, h = state
       else:
         c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
-      concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
+      concat = _linear([inputs, h], 4 * self._num_units, True)
 
       # i = input_gate, j = new_input, f = forget_gate, o = output_gate
       i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
@@ -192,6 +192,13 @@ class BasicLSTMCell(RNNCell):
       return new_h, new_state
 
 
+def _maybe_compile(fun, compiled):
+  if not compiled:
+    return fun
+  else:
+    return function.Defun(noinline=True, compiled=True)(fun)
+
+
 class LSTMCell(RNNCell):
   """Long short-term memory unit (LSTM) recurrent network cell.
 
@@ -219,7 +226,7 @@ class LSTMCell(RNNCell):
                initializer=None, num_proj=None, proj_clip=None,
                num_unit_shards=None, num_proj_shards=None,
                forget_bias=1.0, state_is_tuple=True,
-               activation=tanh):
+               activation=tanh, compiled=False):
     """Initialize the parameters for an LSTM cell.
 
     Args:
@@ -246,6 +253,12 @@ class LSTMCell(RNNCell):
         the `c_state` and `m_state`.  If False, they are concatenated
         along the column axis.  This latter behavior will soon be deprecated.
       activation: Activation function of the inner states.
+      compiled: Python boolean.  If `True`, the core computation of the LSTM
+        cell is compiled via XLA.  Currently this provides speedups for most
+        GPU calculations and for small-batch CPU and embedded calculations.
+
+    Raises:
+      ValueError: if compiled=True and state_is_tuple=False (not supported).
     """
     if not state_is_tuple:
       logging.warn("%s: Using a concatenated state is slower and will soon be "
@@ -257,6 +270,9 @@ class LSTMCell(RNNCell):
           "%s: The num_unit_shards and proj_unit_shards parameters are "
           "deprecated and will be removed in Jan 2017.  "
           "Use a variable scope with a partitioner instead.", self)
+    if not state_is_tuple and compiled:
+      raise ValueError(
+          "Combining state_is_tuple=False and compiled=True is not supported.")
 
     self._num_units = num_units
     self._use_peepholes = use_peepholes
@@ -269,6 +285,7 @@ class LSTMCell(RNNCell):
     self._forget_bias = forget_bias
     self._state_is_tuple = state_is_tuple
     self._activation = activation
+    self._compiled = compiled
 
     if num_proj:
       self._state_size = (
@@ -317,73 +334,111 @@ class LSTMCell(RNNCell):
     """
     num_proj = self._num_units if self._num_proj is None else self._num_proj
 
-    if self._state_is_tuple:
-      (c_prev, m_prev) = state
-    else:
-      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
-      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
+    def _kernel(k_inputs, state_p0, state_p1):
+      """Internal kernel for a single LSTM step.
 
-    dtype = inputs.dtype
-    input_size = inputs.get_shape().with_rank(2)[1]
-    if input_size.value is None:
-      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
-    with vs.variable_scope(scope or "lstm_cell",
-                           initializer=self._initializer) as unit_scope:
-      if self._num_unit_shards is not None:
-        unit_scope.set_partitioner(
-            partitioned_variables.fixed_size_partitioner(
-                self._num_unit_shards))
-      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
-      lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
-                            scope=scope)
-      i, j, f, o = array_ops.split(
-          value=lstm_matrix, num_or_size_splits=4, axis=1)
+      Args:
+        k_inputs: Input Tensor.
+        state_p0: Either the state or the c component of the state.
+        state_p1: Either the state or the m component of the state.
 
-      # Diagonal connections
-      if self._use_peepholes:
-        with vs.variable_scope(unit_scope) as projection_scope:
-          if self._num_unit_shards is not None:
-            projection_scope.set_partitioner(None)
-          w_f_diag = vs.get_variable(
-              "w_f_diag", shape=[self._num_units], dtype=dtype)
-          w_i_diag = vs.get_variable(
-              "w_i_diag", shape=[self._num_units], dtype=dtype)
-          w_o_diag = vs.get_variable(
-              "w_o_diag", shape=[self._num_units], dtype=dtype)
+      Returns:
+        (m, c) or (m, concat([c, m])) depending on state_is_tuple.
 
-      if self._use_peepholes:
-        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
-             sigmoid(i + w_i_diag * c_prev) * self._activation(j))
+      Raises:
+        ValueError: see above docstring.
+      """
+      k_inputs.set_shape(inputs.get_shape())
+      if self._state_is_tuple:
+        (c_prev, m_prev) = state_p0, state_p1
+        c_prev.set_shape(state[0].get_shape())
+        m_prev.set_shape(state[1].get_shape())
       else:
-        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
-             self._activation(j))
+        k_state = state_p0
+        c_prev = array_ops.slice(k_state, [0, 0], [-1, self._num_units])
+        m_prev = array_ops.slice(k_state, [0, self._num_units], [-1, num_proj])
 
-      if self._cell_clip is not None:
-        # pylint: disable=invalid-unary-operand-type
-        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
-        # pylint: enable=invalid-unary-operand-type
+      dtype = k_inputs.dtype
+      input_size = k_inputs.get_shape().with_rank(2)[1]
+      if input_size.value is None:
+        raise ValueError(
+            "Could not infer input size from inputs.get_shape()[-1]")
+      with vs.variable_scope(scope or "lstm_cell",
+                             initializer=self._initializer) as unit_scope:
+        if self._num_unit_shards is not None:
+          unit_scope.set_partitioner(
+              partitioned_variables.fixed_size_partitioner(
+                  self._num_unit_shards))
+        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+        lstm_matrix = _linear(
+            [k_inputs, m_prev], 4 * self._num_units, bias=True,
+            compiled=self._compiled)
+        i, j, f, o = array_ops.split(
+            value=lstm_matrix, num_or_size_splits=4, axis=1)
 
-      if self._use_peepholes:
-        m = sigmoid(o + w_o_diag * c) * self._activation(c)
-      else:
-        m = sigmoid(o) * self._activation(c)
+        # Diagonal connections
+        if self._use_peepholes:
+          with vs.variable_scope(unit_scope) as projection_scope:
+            if self._num_unit_shards is not None:
+              projection_scope.set_partitioner(None)
+            w_f_diag = vs.get_variable(
+                "w_f_diag", shape=[self._num_units], dtype=dtype)
+            w_i_diag = vs.get_variable(
+                "w_i_diag", shape=[self._num_units], dtype=dtype)
+            w_o_diag = vs.get_variable(
+                "w_o_diag", shape=[self._num_units], dtype=dtype)
+          c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+               sigmoid(i + w_i_diag * c_prev) * self._activation(j))
+        else:
+          c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
+               self._activation(j))
 
-      if self._num_proj is not None:
-        with vs.variable_scope("projection") as proj_scope:
-          if self._num_proj_shards is not None:
-            proj_scope.set_partitioner(
-                partitioned_variables.fixed_size_partitioner(
-                    self._num_proj_shards))
-          m = _linear(m, self._num_proj, bias=False, scope=scope)
-
-        if self._proj_clip is not None:
+        if self._cell_clip is not None:
           # pylint: disable=invalid-unary-operand-type
-          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+          c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
           # pylint: enable=invalid-unary-operand-type
 
-    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
-                 array_ops.concat([c, m], 1))
-    return m, new_state
+        if self._use_peepholes:
+          m = sigmoid(o + w_o_diag * c) * self._activation(c)
+        else:
+          m = sigmoid(o) * self._activation(c)
+
+        if self._num_proj is not None:
+          with vs.variable_scope("projection") as proj_scope:
+            if self._num_proj_shards is not None:
+              proj_scope.set_partitioner(
+                  partitioned_variables.fixed_size_partitioner(
+                      self._num_proj_shards))
+            m = _linear(m, self._num_proj, bias=False, compiled=self._compiled)
+
+          if self._proj_clip is not None:
+            # pylint: disable=invalid-unary-operand-type
+            m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+            # pylint: enable=invalid-unary-operand-type
+
+      if self._state_is_tuple:
+        return m, c
+      else:
+        return m, array_ops.concat([c, m], 1)
+
+    compiled_kernel = _maybe_compile(_kernel, self._compiled)
+
+    if self._state_is_tuple:
+      batch_shape = (
+          inputs.get_shape()[:1].merge_with(
+              state[0].get_shape()[:1]).merge_with(
+                  state[1].get_shape()[:1]))
+      emit_m, emit_c = compiled_kernel(inputs, state[0], state[1])
+      emit_c.set_shape(batch_shape.concatenate([state[0].get_shape()[1]]))
+      emit_m.set_shape(batch_shape.concatenate([state[1].get_shape()[1]]))
+      emit_state = LSTMStateTuple(emit_c, emit_m)
+    else:
+      batch_shape = inputs.get_shape()[:1].merge_with(state.get_shape()[:1])
+      emit_m, emit_state = compiled_kernel(inputs, state, state)
+      emit_m.set_shape(batch_shape.concatenate([num_proj]))
+      emit_state.set_shape(batch_shape.concatenate([state.get_shape()[1]]))
+
+    return emit_m, emit_state
 
 
 class OutputProjectionWrapper(RNNCell):
@@ -426,7 +481,7 @@ class OutputProjectionWrapper(RNNCell):
     output, res_state = self._cell(inputs, state)
     # Default scope: "OutputProjectionWrapper"
     with vs.variable_scope(scope or "output_projection_wrapper"):
-      projected = _linear(output, self._output_size, True, scope=scope)
+      projected = _linear(output, self._output_size, True)
     return projected, res_state
 
 
@@ -468,7 +523,7 @@ class InputProjectionWrapper(RNNCell):
     """Run the input projection and then the cell."""
     # Default scope: "InputProjectionWrapper"
     with vs.variable_scope(scope or "input_projection_wrapper"):
-      projected = _linear(inputs, self._num_proj, True, scope=scope)
+      projected = _linear(inputs, self._num_proj, True)
     return self._cell(projected, state)
 
 
@@ -762,7 +817,7 @@ class _SlimRNNCell(RNNCell):
     return output, state
 
 
-def _linear(args, output_size, bias, bias_start=0.0, scope=None):
+def _linear(args, output_size, bias, bias_start=0.0, compiled=False):
   """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
 
   Args:
@@ -770,7 +825,7 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
     output_size: int, second dimension of W[i].
     bias: boolean, whether to add a bias term or not.
     bias_start: starting value to initialize the bias; 0 by default.
-    scope: (optional) Variable scope to create parameters in.
+    compiled: boolean, if True avoid ops that do not work well under XLA.
 
   Returns:
     A 2D Tensor with shape [batch x output_size] equal to
@@ -815,4 +870,8 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
           "biases", [output_size],
           dtype=dtype,
           initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
-  return nn_ops.bias_add(res, biases)
+  if compiled:
+    # TODO(b/34505635): Defuns don't play well with bias_add
+    return res + biases
+  else:
+    return nn_ops.bias_add(res, biases)
diff --git a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py
index fc36c3eae05..c082f7b5309 100644
--- a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py
+++ b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py
@@ -113,7 +113,8 @@ class BasicSamplingDecoder(decoder.Decoder):
         dtypes.int32)
 
   def initialize(self, name=None):
-    return self._sampler.initialize() + (self._initial_state,)
+    with ops.name_scope("basic_sampling_decoder_initialize"):
+      return self._sampler.initialize() + (self._initial_state,)
 
   def step(self, time, inputs, state):
     """Perform a decoding step.
@@ -126,11 +127,12 @@ class BasicSamplingDecoder(decoder.Decoder):
     Returns:
       `(outputs, next_state, next_inputs, finished)`.
     """
-    cell_outputs, next_state = self._cell(inputs, state)
-    (sample_id, finished, next_inputs) = self._sampler.sample(
-        time=time, outputs=cell_outputs, state=next_state)
-    outputs = SamplingDecoderOutput(cell_outputs, sample_id)
-    return (outputs, next_state, next_inputs, finished)
+    with ops.name_scope("basic_sampling_decoder_step"):
+      cell_outputs, next_state = self._cell(inputs, state)
+      (sample_id, finished, next_inputs) = self._sampler.sample(
+          time=time, outputs=cell_outputs, state=next_state)
+      outputs = SamplingDecoderOutput(cell_outputs, sample_id)
+      return (outputs, next_state, next_inputs, finished)
 
 
 class BasicTrainingSampler(Sampler):
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index ebf835d1102..56d4f6ff58d 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -7,7 +7,6 @@ load("//tensorflow:tensorflow.bzl", "if_not_mobile")
 # configure may change the following lines
 WITH_GCP_SUPPORT = False
 WITH_HDFS_SUPPORT = False
-WITH_XLA_SUPPORT = False
 WITH_JEMALLOC = True
 
 # Appends a suffix to a list of deps.
@@ -242,15 +241,3 @@ def tf_additional_cloud_kernel_deps():
   #if WITH_GCP_SUPPORT:
   #  deps = if_not_mobile(["//tensorflow/core:cloud_ops_op_lib"])
   return deps
-
-def tf_additional_plugin_deps():
-  deps = []
-  if WITH_XLA_SUPPORT:
-    deps.append("//tensorflow/compiler/jit")
-  return deps
-
-def tf_additional_license_deps():
-  licenses = []
-  if WITH_XLA_SUPPORT:
-    licenses.append("@llvm//:LICENSE.TXT")
-  return licenses
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 2fa2726bde7..23a7b9065a6 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -2,8 +2,25 @@
 # The functions in this file might be referred by tensorflow.bzl. They have to
 # be separate to avoid cyclic references.
 
+WITH_XLA_SUPPORT = False
+
 def tf_cuda_tests_tags():
   return ["local"]
 
 def tf_sycl_tests_tags():
   return ["local"]
+
+def tf_additional_plugin_deps():
+  deps = []
+  if WITH_XLA_SUPPORT:
+    deps.append("//tensorflow/compiler/jit")
+  return deps
+
+def tf_additional_xla_deps_py():
+  return []
+
+def tf_additional_license_deps():
+  licenses = []
+  if WITH_XLA_SUPPORT:
+    licenses.append("@llvm//:LICENSE.TXT")
+  return licenses
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 1834ce570ef..2befe43be6a 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -23,7 +23,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
 load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
 load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py")
 load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_plugin_deps")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps")
 load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
 
 py_library(
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 7fa7e4a91db..0e5b39af10d 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -12,6 +12,7 @@ load(
     "//tensorflow/core:platform/default/build_config_root.bzl",
     "tf_cuda_tests_tags",
     "tf_sycl_tests_tags",
+    "tf_additional_xla_deps_py",
 )
 load(
     "@local_config_cuda//cuda:build_defs.bzl",
@@ -789,7 +790,10 @@ def py_test(deps=[], **kwargs):
       **kwargs)
 
 def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
-               tags=[], shard_count=1, additional_deps=[], flaky=0):
+               tags=[], shard_count=1, additional_deps=[], flaky=0,
+               xla_enabled=False):
+  if xla_enabled:
+    additional_deps += tf_additional_xla_deps_py()
   native.py_test(
       name=name,
       size=size,
@@ -811,7 +815,8 @@ def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
       srcs_version="PY2AND3")
 
 def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
-                 shard_count=1, additional_deps=[], tags=[], flaky=0):
+                 shard_count=1, additional_deps=[], tags=[], flaky=0,
+                 xla_enabled=False):
   test_tags = tags + tf_cuda_tests_tags()
   tf_py_test(name=name,
              size=size,
@@ -822,10 +827,12 @@ def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
              tags=test_tags,
              shard_count=shard_count,
              additional_deps=additional_deps,
-             flaky=flaky)
+             flaky=flaky,
+             xla_enabled=xla_enabled)
 
 def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
-                shard_count=1, additional_deps=[], tags=[], flaky=0):
+                 shard_count=1, additional_deps=[], tags=[], flaky=0,
+                 xla_enabled=False):
  test_tags = tags + tf_sycl_tests_tags()
  tf_py_test(name=name,
             size=size,
@@ -836,7 +843,8 @@ def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
             tags=test_tags,
             shard_count=shard_count,
             additional_deps=additional_deps,
-            flaky=flaky)
+            flaky=flaky,
+            xla_enabled=xla_enabled)
 
 def py_tests(name,
              srcs,
@@ -845,7 +853,8 @@ def py_tests(name,
              data=[],
              tags=[],
              shard_count=1,
-             prefix=""):
+             prefix="",
+             xla_enabled=False):
   for src in srcs:
     test_name = src.split("/")[-1].split(".")[0]
     if prefix:
@@ -857,13 +866,15 @@ def py_tests(name,
                tags=tags,
                shard_count=shard_count,
                data=data,
-               additional_deps=additional_deps)
+               additional_deps=additional_deps,
+               xla_enabled=xla_enabled)
 
 def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[],
-                  shard_count=1, tags=[], prefix=""):
+                  shard_count=1, tags=[], prefix="", xla_enabled=False):
   test_tags = tags + tf_cuda_tests_tags()
   py_tests(name=name, size=size, srcs=srcs, additional_deps=additional_deps,
-           data=data, tags=test_tags, shard_count=shard_count,prefix=prefix)
+           data=data, tags=test_tags, shard_count=shard_count, prefix=prefix,
+           xla_enabled=xla_enabled)
 
 # Creates a genrule named <name> for running tools/proto_text's generator to
 # make the proto_text functions, for the protos passed in <srcs>.
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 0ffbec8b3cb..85a8b79f859 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -4,7 +4,7 @@
 package(default_visibility = ["//visibility:private"])
 
 load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
-load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_license_deps")
+load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
 
 # This returns a list of headers of all public header libraries (e.g.,
 # framework, lib), and all of the transitive dependencies of those