From 13a1a9a71c084cda8f676f40a711df26e6f3c637 Mon Sep 17 00:00:00 2001 From: Eugene Brevdo <ebrevdo@google.com> Date: Thu, 26 Jan 2017 12:06:38 -0800 Subject: [PATCH] Make LSTMCell use Defuns to speed up static graph builds, add compiled flag. Change: 145703555 --- configure | 4 +- tensorflow/contrib/rnn/BUILD | 2 + .../python/kernel_tests/core_rnn_cell_test.py | 165 ++++++++++++++- .../rnn/python/kernel_tests/core_rnn_test.py | 103 ++++----- .../rnn/python/ops/core_rnn_cell_impl.py | 195 ++++++++++++------ .../seq2seq/python/ops/sampling_decoder.py | 14 +- .../core/platform/default/build_config.bzl | 13 -- .../platform/default/build_config_root.bzl | 17 ++ tensorflow/python/BUILD | 2 +- tensorflow/tensorflow.bzl | 29 ++- tensorflow/tools/pip_package/BUILD | 2 +- 11 files changed, 388 insertions(+), 158 deletions(-) diff --git a/configure b/configure index a8e7bb77385..372ec2cee87 100755 --- a/configure +++ b/configure @@ -168,10 +168,10 @@ done if [ "$TF_ENABLE_XLA" == "1" ]; then # Update Bazel build configuration. - perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl + sed -i -e "s/WITH_XLA_SUPPORT = (False|True)/WITH_XLA_SUPPORT = True/" tensorflow/core/platform/default/build_config_root.bzl else # Update Bazel build configuration. - perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl + sed -i -e "s/WITH_XLA_SUPPORT = (False|True)/WITH_XLA_SUPPORT = False/" tensorflow/core/platform/default/build_config_root.bzl fi diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index bed23625d32..c02423f7a39 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -71,6 +71,7 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + xla_enabled = True, ) cuda_py_tests( @@ -91,6 +92,7 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + xla_enabled = True, ) cuda_py_tests( diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 0d9285ccb8f..8090743e6cf 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import functools +import itertools import sys # TODO: #6568 Remove this hack that makes dlopen() not crash. 
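The change named in the subject line hinges on one pattern: wrap the LSTM cell's per-step computation in a Defun, so the step subgraph is defined once (and, with the new compiled flag, handed to XLA as a single unit) instead of being re-emitted for every time step. Below is a minimal sketch of that pattern, separate from the cell classes touched in the hunks that follow; the _step kernel, shapes, and driver code are illustrative only, and compiled=True additionally assumes a build configured with XLA support (the WITH_XLA_SUPPORT / xla_enabled plumbing this patch adds).

# Minimal sketch of the Defun-wrapping pattern (illustrative names; not part
# of the patch). Runs on stock TF 1.x with compiled=False; compiled=True
# assumes an XLA-enabled build.
import tensorflow as tf
from tensorflow.python.framework import function


def _maybe_compile(fun, compiled):
  # Same helper the patch adds to core_rnn_cell_impl.py: leave the Python
  # function alone, or wrap it so it is built once and optionally XLA-compiled.
  if not compiled:
    return fun
  return function.Defun(noinline=True, compiled=True)(fun)


def _step(x, h):
  # Stand-in for the LSTM kernel: one cheap per-timestep update.
  return tf.tanh(x + h)


step = _maybe_compile(_step, compiled=False)  # flip to True on an XLA build
x = tf.random_uniform([4, 8], seed=1)
h = tf.zeros([4, 8])
new_h = step(x, h)
# Defun outputs lose their static shapes, so restore them by hand -- the same
# reason LSTMCell.__call__ below calls set_shape on the emitted tensors.
new_h.set_shape(h.get_shape())

with tf.Session() as sess:
  print(sess.run(new_h).shape)  # (4, 8)

noinline=True keeps the wrapped kernel as a single call node, so the cell body is built once per graph rather than once per call site, which is the "speed up static graph builds" part of the subject line; the hunks below apply this to LSTMCell and thread the compiled flag through _linear, the tests, and the Bazel macros.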
@@ -33,9 +34,14 @@ import numpy as np from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops @@ -43,10 +49,41 @@ from tensorflow.python.ops import rnn from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test +from tensorflow.python.util import nest # pylint: enable=protected-access +def _CreateMultiLSTMCellOps(batch_size, num_units, input_depth, + num_layers, max_time, compiled): + with variable_scope.variable_scope( + "root", + initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)): + inputs = random_ops.random_uniform( + (max_time, batch_size, input_depth), seed=1) + rnn_cell = core_rnn_cell_impl.MultiRNNCell( + [core_rnn_cell_impl.LSTMCell(num_units, compiled=compiled) + for _ in range(num_layers)]) + initial_state = rnn_cell.zero_state( + batch_size=batch_size, dtype=dtypes.float32) + outputs, final_state = rnn.dynamic_rnn( + cell=rnn_cell, inputs=inputs, initial_state=initial_state, + time_major=True) + flat_final_state = nest.flatten(final_state) + trainable_variables = variables_lib.trainable_variables() + outputs_grad = gradients_impl.gradients( + [outputs], + trainable_variables + [inputs] + nest.flatten(initial_state)) + final_state_grad = gradients_impl.gradients( + flat_final_state, + trainable_variables + [inputs] + nest.flatten(initial_state)) + + return {"outputs": outputs, + "final_state": flat_final_state, + "outputs_grad": outputs_grad, + "final_state_grad": final_state_grad} + + class RNNCellTest(test.TestCase): def testLinear(self): @@ -117,8 +154,8 @@ class RNNCellTest(test.TestCase): x = array_ops.zeros([1, 2]) m = array_ops.zeros([1, 8]) g, out_m = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.BasicLSTMCell( - 2, state_is_tuple=False)] * 2, + [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + for _ in range(2)], state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run( @@ -165,7 +202,8 @@ class RNNCellTest(test.TestCase): m0 = (array_ops.zeros([1, 2]),) * 2 m1 = (array_ops.zeros([1, 2]),) * 2 cell = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.BasicLSTMCell(2)] * 2, state_is_tuple=True) + [core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)], + state_is_tuple=True) self.assertTrue(isinstance(cell.state_size, tuple)) self.assertTrue( isinstance(cell.state_size[0], core_rnn_cell_impl.LSTMStateTuple)) @@ -197,8 +235,8 @@ class RNNCellTest(test.TestCase): m0 = array_ops.zeros([1, 4]) m1 = array_ops.zeros([1, 4]) cell = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.BasicLSTMCell( - 2, state_is_tuple=False)] * 2, + [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) + for _ in range(2)], state_is_tuple=True) g, (out_m0, out_m1) = cell(x, (m0, m1)) sess.run([variables_lib.global_variables_initializer()]) @@ -407,7 +445,8 @@ class RNNCellTest(test.TestCase): x = array_ops.zeros([1, 
2]) m = array_ops.zeros([1, 4]) _, ml = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=False)(x, m) + [core_rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=False)(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run(ml, { x.name: np.array([[1., 1.]]), @@ -416,6 +455,48 @@ class RNNCellTest(test.TestCase): # The numbers in results were not calculated, this is just a smoke test. self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) + def testMultiRNNCellWithLSTMCellAndXLA(self): + # TODO(b/34735319): Don't run this test if XLA is not available. + batch_size = 16 + num_units = 32 + input_depth = 12 + num_layers = 2 + max_time = 20 + + random_seed.set_random_seed(1234) + with self.test_session(graph=ops.Graph()) as sess: + xla_ops = _CreateMultiLSTMCellOps( + batch_size=batch_size, num_units=num_units, + input_depth=input_depth, num_layers=num_layers, + max_time=max_time, + compiled=True) + sess.run([variables_lib.global_variables_initializer()]) + xla_results = sess.run(xla_ops) + + random_seed.set_random_seed(1234) + with self.test_session(graph=ops.Graph()) as sess: + non_xla_ops = _CreateMultiLSTMCellOps( + batch_size=batch_size, num_units=num_units, + input_depth=input_depth, num_layers=num_layers, + max_time=max_time, + compiled=False) + sess.run([variables_lib.global_variables_initializer()]) + non_xla_results = sess.run(non_xla_ops) + + self.assertAllClose(non_xla_results["outputs"], xla_results["outputs"]) + + for xla_value, non_xla_value in zip( + xla_results["final_state"], non_xla_results["final_state"]): + self.assertAllClose(xla_value, non_xla_value) + + for xla_g, non_xla_g in zip( + xla_results["outputs_grad"], non_xla_results["outputs_grad"]): + self.assertAllClose(xla_g, non_xla_g) + + for xla_g, non_xla_g in zip( + xla_results["final_state_grad"], non_xla_results["final_state_grad"]): + self.assertAllClose(xla_g, non_xla_g) + def testMultiRNNCellWithStateTuple(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -427,11 +508,12 @@ class RNNCellTest(test.TestCase): # Test incorrectness of state with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"): core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(2)] * 2, + [core_rnn_cell_impl.GRUCell(2) for _ in range(2)], state_is_tuple=True)(x, m_bad) _, ml = core_rnn_cell_impl.MultiRNNCell( - [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=True)(x, m_good) + [core_rnn_cell_impl.GRUCell(2) for _ in range(2)], + state_is_tuple=True)(x, m_good) sess.run([variables_lib.global_variables_initializer()]) res = sess.run(ml, { @@ -490,7 +572,7 @@ class SlimRNNCellTest(test.TestCase): self.assertAllClose(res[1], res[3]) -def basic_rnn_cell(inputs, state, num_units, scope=None): +def basic_rnn_cell(inputs, state, num_units, scope=None): # pylint: disable=invalid-name if state is None: if inputs is not None: batch_size = inputs.get_shape()[0] @@ -512,5 +594,70 @@ def basic_rnn_cell(inputs, state, num_units, scope=None): return output, output +class BenchmarkLSTMCellXLA(test.Benchmark): + + def benchmarkDynamicRNNWithMultiLSTMCell(self): + num_layers = 3 + max_time = 50 + print("benchmarkDynamicRNNWithMultiLSTMCell") + print("\t" + + "\t".join(["inter_th", "intra_th", + "batch_size", "num_units", "input_depth", "device", + "compiled", "wall_time"])) + + warmup_run = True + for (threads, + device, + num_units, + batch_size, + input_depth, + compiled) in itertools.product( + [{"inter": 0, "intra": 0}, 
{"inter": 1, "intra": 4}], + ["cpu", "gpu"], + [32, 512], + [1, 32, 256], + [32, 512], + [False, True]): + if threads["inter"] != 0: + # We only care about testing inter/intra op limitations on + # CPU with small batch size, to mimic embedded devices. + if device != "cpu" or batch_size != 1: + continue + if device == "cpu" and batch_size > 32: + continue + random_seed.set_random_seed(1234) + config = config_pb2.ConfigProto( + inter_op_parallelism_threads=threads["inter"], + intra_op_parallelism_threads=threads["intra"], + allow_soft_placement=False) + with session.Session(config=config, graph=ops.Graph()) as sess: + with ops.device("/%s:0" % device): + ops_dict = _CreateMultiLSTMCellOps( + batch_size=batch_size, num_units=num_units, + input_depth=input_depth, num_layers=num_layers, + max_time=max_time, + compiled=compiled) + sess.run([variables_lib.global_variables_initializer()]) + all_ops = nest.flatten(ops_dict.values()) + all_ops_group = control_flow_ops.group(*all_ops) + name_suffix = ( + "inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d" + "_device_%s_xla_%s" % ( + threads["inter"], threads["intra"], + batch_size, num_units, input_depth, device, compiled)) + if warmup_run: + self.run_op_benchmark( + sess, all_ops_group, min_iters=30, name="ignore_warmup") + warmup_run = False + benchmark_results = self.run_op_benchmark( + sess, all_ops_group, min_iters=30, + name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix) + print("\t" + + "\t".join(["%s" % x for x in [ + threads["inter"], threads["intra"], + batch_size, num_units, input_depth, device, compiled, + benchmark_results["wall_time"]]])) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 3c84c34726f..67e026dabf8 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -154,6 +154,7 @@ class RNNTest(test.TestCase): def setUp(self): self._seed = 23489 np.random.seed(self._seed) + ops_lib.reset_default_graph() def testInvalidSequenceLengthShape(self): cell = Plus1RNNCell() @@ -583,7 +584,7 @@ class LSTMTest(test.TestCase): (state_notuple_v,) = sess.run((state_notuple,), feed_dict={inputs[0]: input_value}) state_tuple_v = sess.run(state_tuple, feed_dict={inputs[0]: input_value}) - self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v)) + self.assertAllClose(state_notuple_v, np.hstack(state_tuple_v)) def _testProjSharding(self, use_gpu): num_units = 3 @@ -806,7 +807,7 @@ class LSTMTest(test.TestCase): self.assertEqual(len(outputs0_values), len(outputs2_values)) for o1, o2, o3 in zip(outputs0_values, outputs1_values, outputs2_values): # Same weights used by both RNNs so outputs should be the same. - self.assertAllEqual(o1, o2) + self.assertAllClose(o1, o2) # Different weights used so outputs should be different. 
self.assertTrue(np.linalg.norm(o1 - o3) > 1e-6) @@ -844,7 +845,7 @@ class LSTMTest(test.TestCase): outputs1_values = output_values[max_length:] self.assertEqual(len(outputs0_values), len(outputs1_values)) for out0, out1 in zip(outputs0_values, outputs1_values): - self.assertAllEqual(out0, out1) + self.assertAllClose(out0, out1) def testNoProjNoShardingSimpleStateSaver(self): self._testNoProjNoShardingSimpleStateSaver(use_gpu=False) @@ -934,13 +935,13 @@ class LSTMTest(test.TestCase): feed_dict={inputs[0]: input_value}) outputs_dynamic_v = sess.run(outputs_dynamic, feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) + self.assertAllClose(outputs_static_v, outputs_dynamic_v) state_static_v = sess.run(state_static, feed_dict={inputs[0]: input_value}) state_dynamic_v = sess.run(state_dynamic, feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v)) def testDynamicRNNWithNestedTupleStates(self): num_units = 3 @@ -1003,13 +1004,13 @@ class LSTMTest(test.TestCase): feed_dict={inputs[0]: input_value}) outputs_dynamic_v = sess.run(outputs_dynamic, feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) + self.assertAllClose(outputs_static_v, outputs_dynamic_v) state_static_v = sess.run(nest.flatten(state_static), feed_dict={inputs[0]: input_value}) state_dynamic_v = sess.run(nest.flatten(state_dynamic), feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v)) def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length): time_steps = 8 @@ -1038,7 +1039,9 @@ class LSTMTest(test.TestCase): use_peepholes=True, initializer=initializer, num_proj=num_proj, - state_is_tuple=False) + state_is_tuple=False, + # TODO(b/XXX): Defun name aliasing causes errors + compiled=False) with variable_scope.variable_scope("dynamic_scope"): outputs_static, state_static = core_rnn.static_rnn( @@ -1096,7 +1099,9 @@ class LSTMTest(test.TestCase): use_peepholes=True, initializer=initializer, num_proj=num_proj, - state_is_tuple=False) + state_is_tuple=False, + # TODO(b/XXX): Defun name aliasing causes errors + compiled=False) with variable_scope.variable_scope("dynamic_scope"): outputs_dynamic, state_dynamic = rnn.dynamic_rnn( @@ -1150,10 +1155,10 @@ class LSTMTest(test.TestCase): ######### Step 3: Comparisons self.assertEqual(len(values_static), len(values_dynamic)) for (value_static, value_dynamic) in zip(values_static, values_dynamic): - self.assertAllEqual(value_static, value_dynamic) - self.assertAllEqual(state_value_static, state_value_dynamic) + self.assertAllClose(value_static, value_dynamic) + self.assertAllClose(state_value_static, state_value_dynamic) - self.assertAllEqual(static_grad_values, dynamic_grad_values) + self.assertAllClose(static_grad_values, dynamic_grad_values) self.assertEqual( len(static_individual_grad_values), len(dynamic_individual_grad_values)) @@ -1164,14 +1169,14 @@ class LSTMTest(test.TestCase): for i, (a, b) in enumerate( zip(static_individual_grad_values, dynamic_individual_grad_values)): tf_logging.info("Comparing individual gradients iteration %d" % i) - self.assertAllEqual(a, b) + self.assertAllClose(a, b) for i, (a, b) in enumerate( zip(static_individual_var_grad_values, dynamic_individual_var_grad_values)): tf_logging.info("Comparing 
individual variable gradients iteration %d" % i) - self.assertAllEqual(a, b) + self.assertAllClose(a, b) def testDynamicEquivalentToStaticRNN(self): self._testDynamicEquivalentToStaticRNN( @@ -1293,13 +1298,13 @@ class BidirectionalRNNTest(test.TestCase): # Both sequences in batch are length=8. Check that the time=i # forward output is equal to time=8-1-i backward output for i in xrange(8): - self.assertEqual(out[i][0][0], out[8 - 1 - i][0][3]) - self.assertEqual(out[i][0][1], out[8 - 1 - i][0][4]) - self.assertEqual(out[i][0][2], out[8 - 1 - i][0][5]) + self.assertAllClose(out[i][0][0], out[8 - 1 - i][0][3]) + self.assertAllClose(out[i][0][1], out[8 - 1 - i][0][4]) + self.assertAllClose(out[i][0][2], out[8 - 1 - i][0][5]) for i in xrange(8): - self.assertEqual(out[i][1][0], out[8 - 1 - i][1][3]) - self.assertEqual(out[i][1][1], out[8 - 1 - i][1][4]) - self.assertEqual(out[i][1][2], out[8 - 1 - i][1][5]) + self.assertAllClose(out[i][1][0], out[8 - 1 - i][1][3]) + self.assertAllClose(out[i][1][1], out[8 - 1 - i][1][4]) + self.assertAllClose(out[i][1][2], out[8 - 1 - i][1][5]) # Via the reasoning above, the forward and backward final state should be # exactly the same self.assertAllClose(s_fw, s_bw) @@ -1399,27 +1404,27 @@ class BidirectionalRNNTest(test.TestCase): # Check that the time=0 forward output is equal to time=1 backward output if not use_time_major: out = np.swapaxes(out, 0, 1) - self.assertEqual(out[0][0][0], out[1][0][3]) - self.assertEqual(out[0][0][1], out[1][0][4]) - self.assertEqual(out[0][0][2], out[1][0][5]) + self.assertAllClose(out[0][0][0], out[1][0][3]) + self.assertAllClose(out[0][0][1], out[1][0][4]) + self.assertAllClose(out[0][0][2], out[1][0][5]) # Check that the time=1 forward output is equal to time=0 backward output - self.assertEqual(out[1][0][0], out[0][0][3]) - self.assertEqual(out[1][0][1], out[0][0][4]) - self.assertEqual(out[1][0][2], out[0][0][5]) + self.assertAllClose(out[1][0][0], out[0][0][3]) + self.assertAllClose(out[1][0][1], out[0][0][4]) + self.assertAllClose(out[1][0][2], out[0][0][5]) # Second sequence in batch is length=3 # Check that the time=0 forward output is equal to time=2 backward output - self.assertEqual(out[0][1][0], out[2][1][3]) - self.assertEqual(out[0][1][1], out[2][1][4]) - self.assertEqual(out[0][1][2], out[2][1][5]) + self.assertAllClose(out[0][1][0], out[2][1][3]) + self.assertAllClose(out[0][1][1], out[2][1][4]) + self.assertAllClose(out[0][1][2], out[2][1][5]) # Check that the time=1 forward output is equal to time=1 backward output - self.assertEqual(out[1][1][0], out[1][1][3]) - self.assertEqual(out[1][1][1], out[1][1][4]) - self.assertEqual(out[1][1][2], out[1][1][5]) + self.assertAllClose(out[1][1][0], out[1][1][3]) + self.assertAllClose(out[1][1][1], out[1][1][4]) + self.assertAllClose(out[1][1][2], out[1][1][5]) # Check that the time=2 forward output is equal to time=0 backward output - self.assertEqual(out[2][1][0], out[0][1][3]) - self.assertEqual(out[2][1][1], out[0][1][4]) - self.assertEqual(out[2][1][2], out[0][1][5]) + self.assertAllClose(out[2][1][0], out[0][1][3]) + self.assertAllClose(out[2][1][1], out[0][1][4]) + self.assertAllClose(out[2][1][2], out[0][1][5]) # Via the reasoning above, the forward and backward final state should be # exactly the same self.assertAllClose(s_fw, s_bw) @@ -1560,13 +1565,13 @@ class MultiDimensionalLSTMTest(test.TestCase): outputs_sav_v = sess.run(outputs_sav, feed_dict={inputs_using_dim[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) - 
self.assertAllEqual(outputs_static_v, outputs_sav_v) + self.assertAllClose(outputs_static_v, outputs_dynamic_v) + self.assertAllClose(outputs_static_v, outputs_sav_v) outputs_static_array = np.array(outputs_static_v) outputs_static_array_double = np.concatenate( (outputs_static_array, outputs_static_array), axis=2) outputs_bid_array = np.array(outputs_bid_v) - self.assertAllEqual(outputs_static_array_double, outputs_bid_array) + self.assertAllClose(outputs_static_array_double, outputs_bid_array) state_static_v = sess.run(state_static, feed_dict={inputs[0]: input_value}) @@ -1578,10 +1583,10 @@ class MultiDimensionalLSTMTest(test.TestCase): feed_dict={inputs_using_dim[0]: input_value}) state_sav_v = sess.run(state_sav, feed_dict={inputs_using_dim[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v)) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_sav_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_bw_v)) class NestedLSTMTest(test.TestCase): @@ -1663,14 +1668,14 @@ class NestedLSTMTest(test.TestCase): outputs_bid_v = sess.run(outputs_bid, feed_dict={single_input_using_dim: input_value}) - self.assertAllEqual(outputs_static_v, + self.assertAllClose(outputs_static_v, np.transpose(outputs_dynamic_v, (1, 0, 2, 3))) - self.assertAllEqual(outputs_static_v, outputs_sav_v) + self.assertAllClose(outputs_static_v, outputs_sav_v) outputs_static_array = np.array(outputs_static_v) outputs_static_array_double = np.concatenate( (outputs_static_array, outputs_static_array), axis=3) outputs_bid_array = np.array(outputs_bid_v) - self.assertAllEqual(outputs_static_array_double, outputs_bid_array) + self.assertAllClose(outputs_static_array_double, outputs_bid_array) state_dynamic_v = sess.run(state_dynamic, feed_dict={single_input: input_value}) @@ -1682,10 +1687,10 @@ class NestedLSTMTest(test.TestCase): feed_dict={single_input_using_dim: input_value}) state_sav_v = sess.run(state_sav, feed_dict={single_input_using_dim: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v)) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_sav_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_fw_v)) + self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_bw_v)) class StateSaverRNNTest(test.TestCase): diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index 2d65d956a8b..c2843edaf2e 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -22,6 +22,7 @@ from __future__ import print_function import collections import math +from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops 
import array_ops from tensorflow.python.ops import clip_ops @@ -61,7 +62,7 @@ class BasicRNNCell(RNNCell): """Most basic RNN: output = new_state = act(W * input + U * state + B).""" with vs.variable_scope(scope or "basic_rnn_cell"): output = self._activation( - _linear([inputs, state], self._num_units, True, scope=scope)) + _linear([inputs, state], self._num_units, True)) return output, output @@ -89,14 +90,13 @@ class GRUCell(RNNCell): # We start with bias of 1.0 to not reset and not update. r, u = array_ops.split( value=_linear( - [inputs, state], 2 * self._num_units, True, 1.0, scope=scope), + [inputs, state], 2 * self._num_units, True, 1.0), num_or_size_splits=2, axis=1) r, u = sigmoid(r), sigmoid(u) with vs.variable_scope("candidate"): c = self._activation(_linear([inputs, r * state], - self._num_units, True, - scope=scope)) + self._num_units, True)) new_h = u * state + (1 - u) * c return new_h, new_h @@ -176,7 +176,7 @@ class BasicLSTMCell(RNNCell): c, h = state else: c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1) - concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope) + concat = _linear([inputs, h], 4 * self._num_units, True) # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) @@ -192,6 +192,13 @@ class BasicLSTMCell(RNNCell): return new_h, new_state +def _maybe_compile(fun, compiled): + if not compiled: + return fun + else: + return function.Defun(noinline=True, compiled=True)(fun) + + class LSTMCell(RNNCell): """Long short-term memory unit (LSTM) recurrent network cell. @@ -219,7 +226,7 @@ class LSTMCell(RNNCell): initializer=None, num_proj=None, proj_clip=None, num_unit_shards=None, num_proj_shards=None, forget_bias=1.0, state_is_tuple=True, - activation=tanh): + activation=tanh, compiled=False): """Initialize the parameters for an LSTM cell. Args: @@ -246,6 +253,12 @@ class LSTMCell(RNNCell): the `c_state` and `m_state`. If False, they are concatenated along the column axis. This latter behavior will soon be deprecated. activation: Activation function of the inner states. + compiled: Python boolean. If `True`, the core computation of the LSTM + cell is compiled via XLA. As of now, this provides speedups for + most GPU calculations, and on small batch CPU and embedded calculations. + + Raises: + ValueError: if compiled=True and state_is_tuple=False (not supported). """ if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " @@ -257,6 +270,9 @@ class LSTMCell(RNNCell): "%s: The num_unit_shards and proj_unit_shards parameters are " "deprecated and will be removed in Jan 2017. 
" "Use a variable scope with a partitioner instead.", self) + if not state_is_tuple and compiled: + raise ValueError( + "Combining state_is_tuple=False and compiled=True is not supported.") self._num_units = num_units self._use_peepholes = use_peepholes @@ -269,6 +285,7 @@ class LSTMCell(RNNCell): self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple self._activation = activation + self._compiled = compiled if num_proj: self._state_size = ( @@ -317,73 +334,111 @@ class LSTMCell(RNNCell): """ num_proj = self._num_units if self._num_proj is None else self._num_proj - if self._state_is_tuple: - (c_prev, m_prev) = state - else: - c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) - m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) + def _kernel(k_inputs, state_p0, state_p1): + """Internal kernel for the single step of LSTM. - dtype = inputs.dtype - input_size = inputs.get_shape().with_rank(2)[1] - if input_size.value is None: - raise ValueError("Could not infer input size from inputs.get_shape()[-1]") - with vs.variable_scope(scope or "lstm_cell", - initializer=self._initializer) as unit_scope: - if self._num_unit_shards is not None: - unit_scope.set_partitioner( - partitioned_variables.fixed_size_partitioner( - self._num_unit_shards)) - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True, - scope=scope) - i, j, f, o = array_ops.split( - value=lstm_matrix, num_or_size_splits=4, axis=1) + Args: + k_inputs: Input Tensor. + state_p0: Either the state or the c component of the state. + state_p1: Either the state or the m component of the state. - # Diagonal connections - if self._use_peepholes: - with vs.variable_scope(unit_scope) as projection_scope: - if self._num_unit_shards is not None: - projection_scope.set_partitioner(None) - w_f_diag = vs.get_variable( - "w_f_diag", shape=[self._num_units], dtype=dtype) - w_i_diag = vs.get_variable( - "w_i_diag", shape=[self._num_units], dtype=dtype) - w_o_diag = vs.get_variable( - "w_o_diag", shape=[self._num_units], dtype=dtype) + Returns: + (m, c) or (m, concat([c, m])) depending on state_is_tuple. - if self._use_peepholes: - c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + - sigmoid(i + w_i_diag * c_prev) * self._activation(j)) + Raises: + ValueError: see above docstring. 
+ """ + k_inputs.set_shape(inputs.get_shape()) + if self._state_is_tuple: + (c_prev, m_prev) = state_p0, state_p1 + c_prev.set_shape(state[0].get_shape()) + m_prev.set_shape(state[1].get_shape()) else: - c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * - self._activation(j)) + k_state = state_p0 + c_prev = array_ops.slice(k_state, [0, 0], [-1, self._num_units]) + m_prev = array_ops.slice(k_state, [0, self._num_units], [-1, num_proj]) - if self._cell_clip is not None: - # pylint: disable=invalid-unary-operand-type - c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) - # pylint: enable=invalid-unary-operand-type + dtype = k_inputs.dtype + input_size = k_inputs.get_shape().with_rank(2)[1] + if input_size.value is None: + raise ValueError( + "Could not infer input size from inputs.get_shape()[-1]") + with vs.variable_scope(scope or "lstm_cell", + initializer=self._initializer) as unit_scope: + if self._num_unit_shards is not None: + unit_scope.set_partitioner( + partitioned_variables.fixed_size_partitioner( + self._num_unit_shards)) + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + lstm_matrix = _linear( + [k_inputs, m_prev], 4 * self._num_units, bias=True, + compiled=self._compiled) + i, j, f, o = array_ops.split( + value=lstm_matrix, num_or_size_splits=4, axis=1) - if self._use_peepholes: - m = sigmoid(o + w_o_diag * c) * self._activation(c) - else: - m = sigmoid(o) * self._activation(c) + # Diagonal connections + if self._use_peepholes: + with vs.variable_scope(unit_scope) as projection_scope: + if self._num_unit_shards is not None: + projection_scope.set_partitioner(None) + w_f_diag = vs.get_variable( + "w_f_diag", shape=[self._num_units], dtype=dtype) + w_i_diag = vs.get_variable( + "w_i_diag", shape=[self._num_units], dtype=dtype) + w_o_diag = vs.get_variable( + "w_o_diag", shape=[self._num_units], dtype=dtype) + c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + + sigmoid(i + w_i_diag * c_prev) * self._activation(j)) + else: + c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * + self._activation(j)) - if self._num_proj is not None: - with vs.variable_scope("projection") as proj_scope: - if self._num_proj_shards is not None: - proj_scope.set_partitioner( - partitioned_variables.fixed_size_partitioner( - self._num_proj_shards)) - m = _linear(m, self._num_proj, bias=False, scope=scope) - - if self._proj_clip is not None: + if self._cell_clip is not None: # pylint: disable=invalid-unary-operand-type - m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) # pylint: enable=invalid-unary-operand-type - new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else - array_ops.concat([c, m], 1)) - return m, new_state + if self._use_peepholes: + m = sigmoid(o + w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + with vs.variable_scope("projection") as proj_scope: + if self._num_proj_shards is not None: + proj_scope.set_partitioner( + partitioned_variables.fixed_size_partitioner( + self._num_proj_shards)) + m = _linear(m, self._num_proj, bias=False, compiled=self._compiled) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + if self._state_is_tuple: + return m, c + else: + return m, array_ops.concat([c, m], 1) + + compiled_kernel = 
_maybe_compile(_kernel, self._compiled) + + if self._state_is_tuple: + batch_shape = ( + inputs.get_shape()[:1].merge_with( + state[0].get_shape()[:1]).merge_with( + state[1].get_shape()[:1])) + emit_m, emit_c = compiled_kernel(inputs, state[0], state[1]) + emit_c.set_shape(batch_shape.concatenate([state[0].get_shape()[1]])) + emit_m.set_shape(batch_shape.concatenate([state[1].get_shape()[1]])) + emit_state = LSTMStateTuple(emit_c, emit_m) + else: + batch_shape = inputs.get_shape()[:1].merge_with(state.get_shape()[:1]) + emit_m, emit_state = compiled_kernel(inputs, state, state) + emit_m.set_shape(batch_shape.concatenate([num_proj])) + emit_state.set_shape(batch_shape.concatenate([state.get_shape()[1]])) + + return emit_m, emit_state class OutputProjectionWrapper(RNNCell): @@ -426,7 +481,7 @@ class OutputProjectionWrapper(RNNCell): output, res_state = self._cell(inputs, state) # Default scope: "OutputProjectionWrapper" with vs.variable_scope(scope or "output_projection_wrapper"): - projected = _linear(output, self._output_size, True, scope=scope) + projected = _linear(output, self._output_size, True) return projected, res_state @@ -468,7 +523,7 @@ class InputProjectionWrapper(RNNCell): """Run the input projection and then the cell.""" # Default scope: "InputProjectionWrapper" with vs.variable_scope(scope or "input_projection_wrapper"): - projected = _linear(inputs, self._num_proj, True, scope=scope) + projected = _linear(inputs, self._num_proj, True) return self._cell(projected, state) @@ -762,7 +817,7 @@ class _SlimRNNCell(RNNCell): return output, state -def _linear(args, output_size, bias, bias_start=0.0, scope=None): +def _linear(args, output_size, bias, bias_start=0.0, compiled=False): """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. Args: @@ -770,7 +825,7 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None): output_size: int, second dimension of W[i]. bias: boolean, whether to add a bias term or not. bias_start: starting value to initialize the bias; 0 by default. - scope: (optional) Variable scope to create parameters in. + compiled: boolean, _linear plays nicely with XLA if it is enabled. Returns: A 2D Tensor with shape [batch x output_size] equal to @@ -815,4 +870,8 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None): "biases", [output_size], dtype=dtype, initializer=init_ops.constant_initializer(bias_start, dtype=dtype)) - return nn_ops.bias_add(res, biases) + if compiled: + # TODO(b/34505635): Defuns don't play well with bias_add + return res + biases + else: + return nn_ops.bias_add(res, biases) diff --git a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py index fc36c3eae05..c082f7b5309 100644 --- a/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/sampling_decoder.py @@ -113,7 +113,8 @@ class BasicSamplingDecoder(decoder.Decoder): dtypes.int32) def initialize(self, name=None): - return self._sampler.initialize() + (self._initial_state,) + with ops.name_scope("basic_sampling_decoder_initialize"): + return self._sampler.initialize() + (self._initial_state,) def step(self, time, inputs, state): """Perform a decoding step. @@ -126,11 +127,12 @@ class BasicSamplingDecoder(decoder.Decoder): Returns: `(outputs, next_state, next_inputs, finished)`. 
""" - cell_outputs, next_state = self._cell(inputs, state) - (sample_id, finished, next_inputs) = self._sampler.sample( - time=time, outputs=cell_outputs, state=next_state) - outputs = SamplingDecoderOutput(cell_outputs, sample_id) - return (outputs, next_state, next_inputs, finished) + with ops.name_scope("basic_sampling_decoder_step"): + cell_outputs, next_state = self._cell(inputs, state) + (sample_id, finished, next_inputs) = self._sampler.sample( + time=time, outputs=cell_outputs, state=next_state) + outputs = SamplingDecoderOutput(cell_outputs, sample_id) + return (outputs, next_state, next_inputs, finished) class BasicTrainingSampler(Sampler): diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index ebf835d1102..56d4f6ff58d 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -7,7 +7,6 @@ load("//tensorflow:tensorflow.bzl", "if_not_mobile") # configure may change the following lines WITH_GCP_SUPPORT = False WITH_HDFS_SUPPORT = False -WITH_XLA_SUPPORT = False WITH_JEMALLOC = True # Appends a suffix to a list of deps. @@ -242,15 +241,3 @@ def tf_additional_cloud_kernel_deps(): #if WITH_GCP_SUPPORT: # deps = if_not_mobile(["//tensorflow/core:cloud_ops_op_lib"]) return deps - -def tf_additional_plugin_deps(): - deps = [] - if WITH_XLA_SUPPORT: - deps.append("//tensorflow/compiler/jit") - return deps - -def tf_additional_license_deps(): - licenses = [] - if WITH_XLA_SUPPORT: - licenses.append("@llvm//:LICENSE.TXT") - return licenses diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl index 2fa2726bde7..23a7b9065a6 100644 --- a/tensorflow/core/platform/default/build_config_root.bzl +++ b/tensorflow/core/platform/default/build_config_root.bzl @@ -2,8 +2,25 @@ # The functions in this file might be referred by tensorflow.bzl. They have to # be separate to avoid cyclic references. 
+WITH_XLA_SUPPORT = False + def tf_cuda_tests_tags(): return ["local"] def tf_sycl_tests_tags(): return ["local"] + +def tf_additional_plugin_deps(): + deps = [] + if WITH_XLA_SUPPORT: + deps.append("//tensorflow/compiler/jit") + return deps + +def tf_additional_xla_deps_py(): + return [] + +def tf_additional_license_deps(): + licenses = [] + if WITH_XLA_SUPPORT: + licenses.append("@llvm//:LICENSE.TXT") + return licenses diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 1834ce570ef..2befe43be6a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -23,7 +23,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_tests") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py") load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_plugin_deps") +load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps") load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py") py_library( diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 7fa7e4a91db..0e5b39af10d 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -12,6 +12,7 @@ load( "//tensorflow/core:platform/default/build_config_root.bzl", "tf_cuda_tests_tags", "tf_sycl_tests_tags", + "tf_additional_xla_deps_py", ) load( "@local_config_cuda//cuda:build_defs.bzl", @@ -789,7 +790,10 @@ def py_test(deps=[], **kwargs): **kwargs) def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[], - tags=[], shard_count=1, additional_deps=[], flaky=0): + tags=[], shard_count=1, additional_deps=[], flaky=0, + xla_enabled=False): + if xla_enabled: + additional_deps += tf_additional_xla_deps_py() native.py_test( name=name, size=size, @@ -811,7 +815,8 @@ def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[], srcs_version="PY2AND3") def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[], - shard_count=1, additional_deps=[], tags=[], flaky=0): + shard_count=1, additional_deps=[], tags=[], flaky=0, + xla_enabled=False): test_tags = tags + tf_cuda_tests_tags() tf_py_test(name=name, size=size, @@ -822,10 +827,12 @@ def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[], tags=test_tags, shard_count=shard_count, additional_deps=additional_deps, - flaky=flaky) + flaky=flaky, + xla_enabled=xla_enabled) def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[], - shard_count=1, additional_deps=[], tags=[], flaky=0): + shard_count=1, additional_deps=[], tags=[], flaky=0, + xla_enabled=False): test_tags = tags + tf_sycl_tests_tags() tf_py_test(name=name, size=size, @@ -836,7 +843,8 @@ def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[], tags=test_tags, shard_count=shard_count, additional_deps=additional_deps, - flaky=flaky) + flaky=flaky, + xla_enabled=xla_enabled) def py_tests(name, srcs, @@ -845,7 +853,8 @@ def py_tests(name, data=[], tags=[], shard_count=1, - prefix=""): + prefix="", + xla_enabled=False): for src in srcs: test_name = src.split("/")[-1].split(".")[0] if prefix: @@ -857,13 +866,15 @@ def py_tests(name, tags=tags, shard_count=shard_count, data=data, - additional_deps=additional_deps) + additional_deps=additional_deps, + xla_enabled=xla_enabled) def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[], - shard_count=1, 
tags=[], prefix=""): + shard_count=1, tags=[], prefix="", xla_enabled=False): test_tags = tags + tf_cuda_tests_tags() py_tests(name=name, size=size, srcs=srcs, additional_deps=additional_deps, - data=data, tags=test_tags, shard_count=shard_count,prefix=prefix) + data=data, tags=test_tags, shard_count=shard_count,prefix=prefix, + xla_enabled=xla_enabled) # Creates a genrule named <name> for running tools/proto_text's generator to # make the proto_text functions, for the protos passed in <srcs>. diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 0ffbec8b3cb..85a8b79f859 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -4,7 +4,7 @@ package(default_visibility = ["//visibility:private"]) load("//tensorflow:tensorflow.bzl", "transitive_hdrs") -load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_license_deps") +load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps") # This returns a list of headers of all public header libraries (e.g., # framework, lib), and all of the transitive dependencies of those