Automated rollback of change 145703555
Change: 145809900
parent a2417ef5b0
commit 8fc3981e02

configure (vendored): 4 lines changed
@@ -168,10 +168,10 @@ done
 
 if [ "$TF_ENABLE_XLA" == "1" ]; then
   # Update Bazel build configuration.
-  sed -i -e "s/WITH_XLA_SUPPORT = (False|True)/WITH_XLA_SUPPORT = True/" tensorflow/core/platform/default/build_config_root.bzl
+  perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
 else
   # Update Bazel build configuration.
-  sed -i -e "s/WITH_XLA_SUPPORT = (False|True)/WITH_XLA_SUPPORT = False/" tensorflow/core/platform/default/build_config_root.bzl
+  perl -pi -e "s,WITH_XLA_SUPPORT = (False|True),WITH_XLA_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
 fi
 
 
@@ -71,7 +71,6 @@ cuda_py_tests(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
     ],
-    xla_enabled = True,
 )
 
 cuda_py_tests(
@@ -92,7 +91,6 @@ cuda_py_tests(
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
     ],
-    xla_enabled = True,
 )
 
 cuda_py_tests(
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import functools
-import itertools
 import sys
 
 # TODO: #6568 Remove this hack that makes dlopen() not crash.
@@ -34,14 +33,9 @@ import numpy as np
 
 from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
 from tensorflow.contrib.rnn.python.ops.core_rnn_cell_impl import _linear as linear
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -49,41 +43,10 @@ from tensorflow.python.ops import rnn
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.platform import test
-from tensorflow.python.util import nest
 
 # pylint: enable=protected-access
 
 
-def _CreateMultiLSTMCellOps(batch_size, num_units, input_depth,
-                            num_layers, max_time, compiled):
-  with variable_scope.variable_scope(
-      "root",
-      initializer=init_ops.random_uniform_initializer(-0.1, 0.1, seed=2)):
-    inputs = random_ops.random_uniform(
-        (max_time, batch_size, input_depth), seed=1)
-    rnn_cell = core_rnn_cell_impl.MultiRNNCell(
-        [core_rnn_cell_impl.LSTMCell(num_units, compiled=compiled)
-         for _ in range(num_layers)])
-    initial_state = rnn_cell.zero_state(
-        batch_size=batch_size, dtype=dtypes.float32)
-    outputs, final_state = rnn.dynamic_rnn(
-        cell=rnn_cell, inputs=inputs, initial_state=initial_state,
-        time_major=True)
-    flat_final_state = nest.flatten(final_state)
-    trainable_variables = variables_lib.trainable_variables()
-    outputs_grad = gradients_impl.gradients(
-        [outputs],
-        trainable_variables + [inputs] + nest.flatten(initial_state))
-    final_state_grad = gradients_impl.gradients(
-        flat_final_state,
-        trainable_variables + [inputs] + nest.flatten(initial_state))
-
-    return {"outputs": outputs,
-            "final_state": flat_final_state,
-            "outputs_grad": outputs_grad,
-            "final_state_grad": final_state_grad}
-
-
 class RNNCellTest(test.TestCase):
 
   def testLinear(self):
@@ -154,8 +117,8 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 8])
         g, out_m = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
-             for _ in range(2)],
+            [core_rnn_cell_impl.BasicLSTMCell(
+                2, state_is_tuple=False)] * 2,
             state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(
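Note: the two list-building idioms swapped in this hunk (and in the similar hunks below) differ in Python object identity — `[cell] * 2` repeats a single cell object, while the comprehension constructs independent cells. A minimal, TensorFlow-free sketch of that distinction (the `Cell` class here is made up purely for illustration):

class Cell(object):
  pass

shared = [Cell()] * 2                   # one Cell instance, referenced twice
distinct = [Cell() for _ in range(2)]   # two independent Cell instances
assert shared[0] is shared[1]
assert distinct[0] is not distinct[1]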
@@ -202,8 +165,7 @@ class RNNCellTest(test.TestCase):
         m0 = (array_ops.zeros([1, 2]),) * 2
         m1 = (array_ops.zeros([1, 2]),) * 2
         cell = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(2) for _ in range(2)],
-            state_is_tuple=True)
+            [core_rnn_cell_impl.BasicLSTMCell(2)] * 2, state_is_tuple=True)
         self.assertTrue(isinstance(cell.state_size, tuple))
         self.assertTrue(
             isinstance(cell.state_size[0], core_rnn_cell_impl.LSTMStateTuple))
@@ -235,8 +197,8 @@ class RNNCellTest(test.TestCase):
         m0 = array_ops.zeros([1, 4])
         m1 = array_ops.zeros([1, 4])
         cell = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
-             for _ in range(2)],
+            [core_rnn_cell_impl.BasicLSTMCell(
+                2, state_is_tuple=False)] * 2,
             state_is_tuple=True)
         g, (out_m0, out_m1) = cell(x, (m0, m1))
         sess.run([variables_lib.global_variables_initializer()])
@@ -445,8 +407,7 @@ class RNNCellTest(test.TestCase):
         x = array_ops.zeros([1, 2])
         m = array_ops.zeros([1, 4])
         _, ml = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
-            state_is_tuple=False)(x, m)
+            [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=False)(x, m)
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(ml, {
             x.name: np.array([[1., 1.]]),
@@ -455,48 +416,6 @@ class RNNCellTest(test.TestCase):
       # The numbers in results were not calculated, this is just a smoke test.
       self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]])
 
-  def testMultiRNNCellWithLSTMCellAndXLA(self):
-    # TODO(b/34735319): Don't run this test if XLA is not available.
-    batch_size = 16
-    num_units = 32
-    input_depth = 12
-    num_layers = 2
-    max_time = 20
-
-    random_seed.set_random_seed(1234)
-    with self.test_session(graph=ops.Graph()) as sess:
-      xla_ops = _CreateMultiLSTMCellOps(
-          batch_size=batch_size, num_units=num_units,
-          input_depth=input_depth, num_layers=num_layers,
-          max_time=max_time,
-          compiled=True)
-      sess.run([variables_lib.global_variables_initializer()])
-      xla_results = sess.run(xla_ops)
-
-    random_seed.set_random_seed(1234)
-    with self.test_session(graph=ops.Graph()) as sess:
-      non_xla_ops = _CreateMultiLSTMCellOps(
-          batch_size=batch_size, num_units=num_units,
-          input_depth=input_depth, num_layers=num_layers,
-          max_time=max_time,
-          compiled=False)
-      sess.run([variables_lib.global_variables_initializer()])
-      non_xla_results = sess.run(non_xla_ops)
-
-    self.assertAllClose(non_xla_results["outputs"], xla_results["outputs"])
-
-    for xla_value, non_xla_value in zip(
-        xla_results["final_state"], non_xla_results["final_state"]):
-      self.assertAllClose(xla_value, non_xla_value)
-
-    for xla_g, non_xla_g in zip(
-        xla_results["outputs_grad"], non_xla_results["outputs_grad"]):
-      self.assertAllClose(xla_g, non_xla_g)
-
-    for xla_g, non_xla_g in zip(
-        xla_results["final_state_grad"], non_xla_results["final_state_grad"]):
-      self.assertAllClose(xla_g, non_xla_g)
-
   def testMultiRNNCellWithStateTuple(self):
     with self.test_session() as sess:
       with variable_scope.variable_scope(
@@ -508,12 +427,11 @@ class RNNCellTest(test.TestCase):
         # Test incorrectness of state
         with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
           core_rnn_cell_impl.MultiRNNCell(
-              [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
+              [core_rnn_cell_impl.GRUCell(2)] * 2,
               state_is_tuple=True)(x, m_bad)
 
         _, ml = core_rnn_cell_impl.MultiRNNCell(
-            [core_rnn_cell_impl.GRUCell(2) for _ in range(2)],
-            state_is_tuple=True)(x, m_good)
+            [core_rnn_cell_impl.GRUCell(2)] * 2, state_is_tuple=True)(x, m_good)
 
         sess.run([variables_lib.global_variables_initializer()])
         res = sess.run(ml, {
@@ -572,7 +490,7 @@ class SlimRNNCellTest(test.TestCase):
       self.assertAllClose(res[1], res[3])
 
 
-def basic_rnn_cell(inputs, state, num_units, scope=None):  # pylint: disable=invalid-name
+def basic_rnn_cell(inputs, state, num_units, scope=None):
   if state is None:
     if inputs is not None:
       batch_size = inputs.get_shape()[0]
@@ -594,70 +512,5 @@ def basic_rnn_cell(inputs, state, num_units, scope=None):  # pylint: disable=invalid-name
   return output, output
 
 
-class BenchmarkLSTMCellXLA(test.Benchmark):
-
-  def benchmarkDynamicRNNWithMultiLSTMCell(self):
-    num_layers = 3
-    max_time = 50
-    print("benchmarkDynamicRNNWithMultiLSTMCell")
-    print("\t" +
-          "\t".join(["inter_th", "intra_th",
-                     "batch_size", "num_units", "input_depth", "device",
-                     "compiled", "wall_time"]))
-
-    warmup_run = True
-    for (threads,
-         device,
-         num_units,
-         batch_size,
-         input_depth,
-         compiled) in itertools.product(
-             [{"inter": 0, "intra": 0}, {"inter": 1, "intra": 4}],
-             ["cpu", "gpu"],
-             [32, 512],
-             [1, 32, 256],
-             [32, 512],
-             [False, True]):
-      if threads["inter"] != 0:
-        # We only care about testing inter/intra op limitations on
-        # CPU with small batch size, to mimic embedded devices.
-        if device != "cpu" or batch_size != 1:
-          continue
-      if device == "cpu" and batch_size > 32:
-        continue
-      random_seed.set_random_seed(1234)
-      config = config_pb2.ConfigProto(
-          inter_op_parallelism_threads=threads["inter"],
-          intra_op_parallelism_threads=threads["intra"],
-          allow_soft_placement=False)
-      with session.Session(config=config, graph=ops.Graph()) as sess:
-        with ops.device("/%s:0" % device):
-          ops_dict = _CreateMultiLSTMCellOps(
-              batch_size=batch_size, num_units=num_units,
-              input_depth=input_depth, num_layers=num_layers,
-              max_time=max_time,
-              compiled=compiled)
-        sess.run([variables_lib.global_variables_initializer()])
-        all_ops = nest.flatten(ops_dict.values())
-        all_ops_group = control_flow_ops.group(*all_ops)
-        name_suffix = (
-            "inter_th_%d_intra_th_%d_bs_%d_units_%d_inputdepth_%d"
-            "_device_%s_xla_%s" % (
-                threads["inter"], threads["intra"],
-                batch_size, num_units, input_depth, device, compiled))
-        if warmup_run:
-          self.run_op_benchmark(
-              sess, all_ops_group, min_iters=30, name="ignore_warmup")
-          warmup_run = False
-        benchmark_results = self.run_op_benchmark(
-            sess, all_ops_group, min_iters=30,
-            name="benchmarkDynamicRNNWithMultiLSTMCell_%s" % name_suffix)
-        print("\t" +
-              "\t".join(["%s" % x for x in [
-                  threads["inter"], threads["intra"],
-                  batch_size, num_units, input_depth, device, compiled,
-                  benchmark_results["wall_time"]]]))
-
-
 if __name__ == "__main__":
   test.main()
@@ -154,7 +154,6 @@ class RNNTest(test.TestCase):
   def setUp(self):
     self._seed = 23489
     np.random.seed(self._seed)
-    ops_lib.reset_default_graph()
 
   def testInvalidSequenceLengthShape(self):
     cell = Plus1RNNCell()
@@ -584,7 +583,7 @@ class LSTMTest(test.TestCase):
       (state_notuple_v,) = sess.run((state_notuple,),
                                     feed_dict={inputs[0]: input_value})
       state_tuple_v = sess.run(state_tuple, feed_dict={inputs[0]: input_value})
-      self.assertAllClose(state_notuple_v, np.hstack(state_tuple_v))
+      self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v))
 
   def _testProjSharding(self, use_gpu):
     num_units = 3
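Note: this and several hunks below swap `assertAllClose` for `assertAllEqual`. The first tolerates small numerical differences, the second requires element-wise exact equality. A minimal sketch of that distinction using plain NumPy (the real methods live on `tf.test.TestCase`; the example arrays are made up):

import numpy as np

a = np.array([1.0, 2.0])
b = a + 1e-7                                 # numerically close, not identical
np.testing.assert_allclose(a, b, rtol=1e-6)  # passes: values are merely close
np.testing.assert_array_equal(a, a)          # passes: exact equality
# np.testing.assert_array_equal(a, b) would raise: the values are not identical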
@@ -807,7 +806,7 @@ class LSTMTest(test.TestCase):
       self.assertEqual(len(outputs0_values), len(outputs2_values))
       for o1, o2, o3 in zip(outputs0_values, outputs1_values, outputs2_values):
         # Same weights used by both RNNs so outputs should be the same.
-        self.assertAllClose(o1, o2)
+        self.assertAllEqual(o1, o2)
         # Different weights used so outputs should be different.
         self.assertTrue(np.linalg.norm(o1 - o3) > 1e-6)
 
@@ -845,7 +844,7 @@ class LSTMTest(test.TestCase):
       outputs1_values = output_values[max_length:]
       self.assertEqual(len(outputs0_values), len(outputs1_values))
       for out0, out1 in zip(outputs0_values, outputs1_values):
-        self.assertAllClose(out0, out1)
+        self.assertAllEqual(out0, out1)
 
   def testNoProjNoShardingSimpleStateSaver(self):
     self._testNoProjNoShardingSimpleStateSaver(use_gpu=False)
@@ -935,13 +934,13 @@ class LSTMTest(test.TestCase):
                                   feed_dict={inputs[0]: input_value})
       outputs_dynamic_v = sess.run(outputs_dynamic,
                                    feed_dict={inputs[0]: input_value})
-      self.assertAllClose(outputs_static_v, outputs_dynamic_v)
+      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
 
       state_static_v = sess.run(state_static,
                                 feed_dict={inputs[0]: input_value})
       state_dynamic_v = sess.run(state_dynamic,
                                  feed_dict={inputs[0]: input_value})
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
 
   def testDynamicRNNWithNestedTupleStates(self):
     num_units = 3
@@ -1004,13 +1003,13 @@ class LSTMTest(test.TestCase):
                                   feed_dict={inputs[0]: input_value})
       outputs_dynamic_v = sess.run(outputs_dynamic,
                                    feed_dict={inputs[0]: input_value})
-      self.assertAllClose(outputs_static_v, outputs_dynamic_v)
+      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
 
       state_static_v = sess.run(nest.flatten(state_static),
                                 feed_dict={inputs[0]: input_value})
       state_dynamic_v = sess.run(nest.flatten(state_dynamic),
                                  feed_dict={inputs[0]: input_value})
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
 
   def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
     time_steps = 8
@@ -1039,9 +1038,7 @@ class LSTMTest(test.TestCase):
         use_peepholes=True,
         initializer=initializer,
         num_proj=num_proj,
-        state_is_tuple=False,
-        # TODO(b/XXX): Defun name aliasing causes errors
-        compiled=False)
+        state_is_tuple=False)
 
     with variable_scope.variable_scope("dynamic_scope"):
       outputs_static, state_static = core_rnn.static_rnn(
@@ -1099,9 +1096,7 @@ class LSTMTest(test.TestCase):
         use_peepholes=True,
         initializer=initializer,
         num_proj=num_proj,
-        state_is_tuple=False,
-        # TODO(b/XXX): Defun name aliasing causes errors
-        compiled=False)
+        state_is_tuple=False)
 
     with variable_scope.variable_scope("dynamic_scope"):
       outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
@@ -1155,10 +1150,10 @@ class LSTMTest(test.TestCase):
     ######### Step 3: Comparisons
     self.assertEqual(len(values_static), len(values_dynamic))
     for (value_static, value_dynamic) in zip(values_static, values_dynamic):
-      self.assertAllClose(value_static, value_dynamic)
-    self.assertAllClose(state_value_static, state_value_dynamic)
+      self.assertAllEqual(value_static, value_dynamic)
+    self.assertAllEqual(state_value_static, state_value_dynamic)
 
-    self.assertAllClose(static_grad_values, dynamic_grad_values)
+    self.assertAllEqual(static_grad_values, dynamic_grad_values)
 
     self.assertEqual(
         len(static_individual_grad_values), len(dynamic_individual_grad_values))
@@ -1169,14 +1164,14 @@ class LSTMTest(test.TestCase):
     for i, (a, b) in enumerate(
         zip(static_individual_grad_values, dynamic_individual_grad_values)):
       tf_logging.info("Comparing individual gradients iteration %d" % i)
-      self.assertAllClose(a, b)
+      self.assertAllEqual(a, b)
 
     for i, (a, b) in enumerate(
         zip(static_individual_var_grad_values,
            dynamic_individual_var_grad_values)):
       tf_logging.info("Comparing individual variable gradients iteration %d" %
                       i)
-      self.assertAllClose(a, b)
+      self.assertAllEqual(a, b)
 
   def testDynamicEquivalentToStaticRNN(self):
     self._testDynamicEquivalentToStaticRNN(
@@ -1298,13 +1293,13 @@ class BidirectionalRNNTest(test.TestCase):
       # Both sequences in batch are length=8. Check that the time=i
       # forward output is equal to time=8-1-i backward output
       for i in xrange(8):
-        self.assertAllClose(out[i][0][0], out[8 - 1 - i][0][3])
-        self.assertAllClose(out[i][0][1], out[8 - 1 - i][0][4])
-        self.assertAllClose(out[i][0][2], out[8 - 1 - i][0][5])
+        self.assertEqual(out[i][0][0], out[8 - 1 - i][0][3])
+        self.assertEqual(out[i][0][1], out[8 - 1 - i][0][4])
+        self.assertEqual(out[i][0][2], out[8 - 1 - i][0][5])
       for i in xrange(8):
-        self.assertAllClose(out[i][1][0], out[8 - 1 - i][1][3])
-        self.assertAllClose(out[i][1][1], out[8 - 1 - i][1][4])
-        self.assertAllClose(out[i][1][2], out[8 - 1 - i][1][5])
+        self.assertEqual(out[i][1][0], out[8 - 1 - i][1][3])
+        self.assertEqual(out[i][1][1], out[8 - 1 - i][1][4])
+        self.assertEqual(out[i][1][2], out[8 - 1 - i][1][5])
       # Via the reasoning above, the forward and backward final state should be
       # exactly the same
       self.assertAllClose(s_fw, s_bw)
@@ -1404,27 +1399,27 @@ class BidirectionalRNNTest(test.TestCase):
       # Check that the time=0 forward output is equal to time=1 backward output
       if not use_time_major:
         out = np.swapaxes(out, 0, 1)
-      self.assertAllClose(out[0][0][0], out[1][0][3])
-      self.assertAllClose(out[0][0][1], out[1][0][4])
-      self.assertAllClose(out[0][0][2], out[1][0][5])
+      self.assertEqual(out[0][0][0], out[1][0][3])
+      self.assertEqual(out[0][0][1], out[1][0][4])
+      self.assertEqual(out[0][0][2], out[1][0][5])
       # Check that the time=1 forward output is equal to time=0 backward output
-      self.assertAllClose(out[1][0][0], out[0][0][3])
-      self.assertAllClose(out[1][0][1], out[0][0][4])
-      self.assertAllClose(out[1][0][2], out[0][0][5])
+      self.assertEqual(out[1][0][0], out[0][0][3])
+      self.assertEqual(out[1][0][1], out[0][0][4])
+      self.assertEqual(out[1][0][2], out[0][0][5])
 
       # Second sequence in batch is length=3
       # Check that the time=0 forward output is equal to time=2 backward output
-      self.assertAllClose(out[0][1][0], out[2][1][3])
-      self.assertAllClose(out[0][1][1], out[2][1][4])
-      self.assertAllClose(out[0][1][2], out[2][1][5])
+      self.assertEqual(out[0][1][0], out[2][1][3])
+      self.assertEqual(out[0][1][1], out[2][1][4])
+      self.assertEqual(out[0][1][2], out[2][1][5])
       # Check that the time=1 forward output is equal to time=1 backward output
-      self.assertAllClose(out[1][1][0], out[1][1][3])
-      self.assertAllClose(out[1][1][1], out[1][1][4])
-      self.assertAllClose(out[1][1][2], out[1][1][5])
+      self.assertEqual(out[1][1][0], out[1][1][3])
+      self.assertEqual(out[1][1][1], out[1][1][4])
+      self.assertEqual(out[1][1][2], out[1][1][5])
       # Check that the time=2 forward output is equal to time=0 backward output
-      self.assertAllClose(out[2][1][0], out[0][1][3])
-      self.assertAllClose(out[2][1][1], out[0][1][4])
-      self.assertAllClose(out[2][1][2], out[0][1][5])
+      self.assertEqual(out[2][1][0], out[0][1][3])
+      self.assertEqual(out[2][1][1], out[0][1][4])
+      self.assertEqual(out[2][1][2], out[0][1][5])
       # Via the reasoning above, the forward and backward final state should be
       # exactly the same
       self.assertAllClose(s_fw, s_bw)
@@ -1565,13 +1560,13 @@ class MultiDimensionalLSTMTest(test.TestCase):
       outputs_sav_v = sess.run(outputs_sav,
                                feed_dict={inputs_using_dim[0]: input_value})
 
-      self.assertAllClose(outputs_static_v, outputs_dynamic_v)
-      self.assertAllClose(outputs_static_v, outputs_sav_v)
+      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
+      self.assertAllEqual(outputs_static_v, outputs_sav_v)
       outputs_static_array = np.array(outputs_static_v)
       outputs_static_array_double = np.concatenate(
           (outputs_static_array, outputs_static_array), axis=2)
       outputs_bid_array = np.array(outputs_bid_v)
-      self.assertAllClose(outputs_static_array_double, outputs_bid_array)
+      self.assertAllEqual(outputs_static_array_double, outputs_bid_array)
 
       state_static_v = sess.run(state_static,
                                 feed_dict={inputs[0]: input_value})
@@ -1583,10 +1578,10 @@ class MultiDimensionalLSTMTest(test.TestCase):
                                feed_dict={inputs_using_dim[0]: input_value})
       state_sav_v = sess.run(state_sav,
                              feed_dict={inputs_using_dim[0]: input_value})
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v))
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_sav_v))
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
 
 
 class NestedLSTMTest(test.TestCase):
@@ -1668,14 +1663,14 @@ class NestedLSTMTest(test.TestCase):
       outputs_bid_v = sess.run(outputs_bid,
                                feed_dict={single_input_using_dim: input_value})
 
-      self.assertAllClose(outputs_static_v,
+      self.assertAllEqual(outputs_static_v,
                           np.transpose(outputs_dynamic_v, (1, 0, 2, 3)))
-      self.assertAllClose(outputs_static_v, outputs_sav_v)
+      self.assertAllEqual(outputs_static_v, outputs_sav_v)
       outputs_static_array = np.array(outputs_static_v)
       outputs_static_array_double = np.concatenate(
           (outputs_static_array, outputs_static_array), axis=3)
       outputs_bid_array = np.array(outputs_bid_v)
-      self.assertAllClose(outputs_static_array_double, outputs_bid_array)
+      self.assertAllEqual(outputs_static_array_double, outputs_bid_array)
 
       state_dynamic_v = sess.run(state_dynamic,
                                  feed_dict={single_input: input_value})
@@ -1687,10 +1682,10 @@ class NestedLSTMTest(test.TestCase):
                                feed_dict={single_input_using_dim: input_value})
       state_sav_v = sess.run(state_sav,
                              feed_dict={single_input_using_dim: input_value})
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_dynamic_v))
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_sav_v))
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
-      self.assertAllClose(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_sav_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_fw_v))
+      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_bid_bw_v))
 
 
 class StateSaverRNNTest(test.TestCase):
@@ -22,7 +22,6 @@ from __future__ import print_function
 import collections
 import math
 
-from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -62,7 +61,7 @@ class BasicRNNCell(RNNCell):
     """Most basic RNN: output = new_state = act(W * input + U * state + B)."""
     with vs.variable_scope(scope or "basic_rnn_cell"):
       output = self._activation(
-          _linear([inputs, state], self._num_units, True))
+          _linear([inputs, state], self._num_units, True, scope=scope))
     return output, output
 
 
@@ -90,13 +89,14 @@ class GRUCell(RNNCell):
         # We start with bias of 1.0 to not reset and not update.
         r, u = array_ops.split(
             value=_linear(
-                [inputs, state], 2 * self._num_units, True, 1.0),
+                [inputs, state], 2 * self._num_units, True, 1.0, scope=scope),
             num_or_size_splits=2,
             axis=1)
         r, u = sigmoid(r), sigmoid(u)
       with vs.variable_scope("candidate"):
         c = self._activation(_linear([inputs, r * state],
-                                     self._num_units, True))
+                                     self._num_units, True,
+                                     scope=scope))
       new_h = u * state + (1 - u) * c
     return new_h, new_h
 
@@ -176,7 +176,7 @@ class BasicLSTMCell(RNNCell):
         c, h = state
       else:
         c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
-      concat = _linear([inputs, h], 4 * self._num_units, True)
+      concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
 
       # i = input_gate, j = new_input, f = forget_gate, o = output_gate
       i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
@@ -192,13 +192,6 @@ class BasicLSTMCell(RNNCell):
       return new_h, new_state
 
 
-def _maybe_compile(fun, compiled):
-  if not compiled:
-    return fun
-  else:
-    return function.Defun(noinline=True, compiled=True)(fun)
-
-
 class LSTMCell(RNNCell):
   """Long short-term memory unit (LSTM) recurrent network cell.
 
@@ -226,7 +219,7 @@ class LSTMCell(RNNCell):
                initializer=None, num_proj=None, proj_clip=None,
                num_unit_shards=None, num_proj_shards=None,
                forget_bias=1.0, state_is_tuple=True,
-               activation=tanh, compiled=False):
+               activation=tanh):
     """Initialize the parameters for an LSTM cell.
 
     Args:
@@ -253,12 +246,6 @@ class LSTMCell(RNNCell):
        the `c_state` and `m_state`. If False, they are concatenated
        along the column axis. This latter behavior will soon be deprecated.
      activation: Activation function of the inner states.
-      compiled: Python boolean. If `True`, the core computation of the LSTM
-        cell is compiled via XLA. As of now, this provides speedups for
-        most GPU calculations, and on small batch CPU and embedded calculations.
-
-    Raises:
-      ValueError: if compiled=True and state_is_tuple=False (not supported).
     """
     if not state_is_tuple:
       logging.warn("%s: Using a concatenated state is slower and will soon be "
||||||
@ -270,9 +257,6 @@ class LSTMCell(RNNCell):
|
|||||||
"%s: The num_unit_shards and proj_unit_shards parameters are "
|
"%s: The num_unit_shards and proj_unit_shards parameters are "
|
||||||
"deprecated and will be removed in Jan 2017. "
|
"deprecated and will be removed in Jan 2017. "
|
||||||
"Use a variable scope with a partitioner instead.", self)
|
"Use a variable scope with a partitioner instead.", self)
|
||||||
if not state_is_tuple and compiled:
|
|
||||||
raise ValueError(
|
|
||||||
"Combining state_is_tuple=False and compiled=True is not supported.")
|
|
||||||
|
|
||||||
self._num_units = num_units
|
self._num_units = num_units
|
||||||
self._use_peepholes = use_peepholes
|
self._use_peepholes = use_peepholes
|
||||||
@@ -285,7 +269,6 @@ class LSTMCell(RNNCell):
     self._forget_bias = forget_bias
     self._state_is_tuple = state_is_tuple
     self._activation = activation
-    self._compiled = compiled
 
     if num_proj:
       self._state_size = (
@@ -334,111 +317,73 @@ class LSTMCell(RNNCell):
     """
     num_proj = self._num_units if self._num_proj is None else self._num_proj
 
-    def _kernel(k_inputs, state_p0, state_p1):
-      """Internal kernel for the single step of LSTM.
-
-      Args:
-        k_inputs: Input Tensor.
-        state_p0: Either the state or the c component of the state.
-        state_p1: Either the state or the m component of the state.
-
-      Returns:
-        (m, c) or (m, concat([c, m])) depending on state_is_tuple.
-
-      Raises:
-        ValueError: see above docstring.
-      """
-      k_inputs.set_shape(inputs.get_shape())
-      if self._state_is_tuple:
-        (c_prev, m_prev) = state_p0, state_p1
-        c_prev.set_shape(state[0].get_shape())
-        m_prev.set_shape(state[1].get_shape())
-      else:
-        k_state = state_p0
-        c_prev = array_ops.slice(k_state, [0, 0], [-1, self._num_units])
-        m_prev = array_ops.slice(k_state, [0, self._num_units], [-1, num_proj])
-
-      dtype = k_inputs.dtype
-      input_size = k_inputs.get_shape().with_rank(2)[1]
-      if input_size.value is None:
-        raise ValueError(
-            "Could not infer input size from inputs.get_shape()[-1]")
-      with vs.variable_scope(scope or "lstm_cell",
-                             initializer=self._initializer) as unit_scope:
-        if self._num_unit_shards is not None:
-          unit_scope.set_partitioner(
-              partitioned_variables.fixed_size_partitioner(
-                  self._num_unit_shards))
-        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
-        lstm_matrix = _linear(
-            [k_inputs, m_prev], 4 * self._num_units, bias=True,
-            compiled=self._compiled)
-        i, j, f, o = array_ops.split(
-            value=lstm_matrix, num_or_size_splits=4, axis=1)
-
-        # Diagonal connections
-        if self._use_peepholes:
-          with vs.variable_scope(unit_scope) as projection_scope:
-            if self._num_unit_shards is not None:
-              projection_scope.set_partitioner(None)
-            w_f_diag = vs.get_variable(
-                "w_f_diag", shape=[self._num_units], dtype=dtype)
-            w_i_diag = vs.get_variable(
-                "w_i_diag", shape=[self._num_units], dtype=dtype)
-            w_o_diag = vs.get_variable(
-                "w_o_diag", shape=[self._num_units], dtype=dtype)
-          c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
-               sigmoid(i + w_i_diag * c_prev) * self._activation(j))
-        else:
-          c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
-               self._activation(j))
-
-        if self._cell_clip is not None:
-          # pylint: disable=invalid-unary-operand-type
-          c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
-          # pylint: enable=invalid-unary-operand-type
-        if self._use_peepholes:
-          m = sigmoid(o + w_o_diag * c) * self._activation(c)
-        else:
-          m = sigmoid(o) * self._activation(c)
-
-        if self._num_proj is not None:
-          with vs.variable_scope("projection") as proj_scope:
-            if self._num_proj_shards is not None:
-              proj_scope.set_partitioner(
-                  partitioned_variables.fixed_size_partitioner(
-                      self._num_proj_shards))
-            m = _linear(m, self._num_proj, bias=False, compiled=self._compiled)
-
-          if self._proj_clip is not None:
-            # pylint: disable=invalid-unary-operand-type
-            m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
-            # pylint: enable=invalid-unary-operand-type
-
-      if self._state_is_tuple:
-        return m, c
-      else:
-        return m, array_ops.concat([c, m], 1)
-
-    compiled_kernel = _maybe_compile(_kernel, self._compiled)
-
-    if self._state_is_tuple:
-      batch_shape = (
-          inputs.get_shape()[:1].merge_with(
-              state[0].get_shape()[:1]).merge_with(
-                  state[1].get_shape()[:1]))
-      emit_m, emit_c = compiled_kernel(inputs, state[0], state[1])
-      emit_c.set_shape(batch_shape.concatenate([state[0].get_shape()[1]]))
-      emit_m.set_shape(batch_shape.concatenate([state[1].get_shape()[1]]))
-      emit_state = LSTMStateTuple(emit_c, emit_m)
-    else:
-      batch_shape = inputs.get_shape()[:1].merge_with(state.get_shape()[:1])
-      emit_m, emit_state = compiled_kernel(inputs, state, state)
-      emit_m.set_shape(batch_shape.concatenate([num_proj]))
-      emit_state.set_shape(batch_shape.concatenate([state.get_shape()[1]]))
-
-    return emit_m, emit_state
+    if self._state_is_tuple:
+      (c_prev, m_prev) = state
+    else:
+      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
+      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])
+
+    dtype = inputs.dtype
+    input_size = inputs.get_shape().with_rank(2)[1]
+    if input_size.value is None:
+      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+    with vs.variable_scope(scope or "lstm_cell",
+                           initializer=self._initializer) as unit_scope:
+      if self._num_unit_shards is not None:
+        unit_scope.set_partitioner(
+            partitioned_variables.fixed_size_partitioner(
+                self._num_unit_shards))
+      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+      lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
+                            scope=scope)
+      i, j, f, o = array_ops.split(
+          value=lstm_matrix, num_or_size_splits=4, axis=1)
+
+      # Diagonal connections
+      if self._use_peepholes:
+        with vs.variable_scope(unit_scope) as projection_scope:
+          if self._num_unit_shards is not None:
+            projection_scope.set_partitioner(None)
+          w_f_diag = vs.get_variable(
+              "w_f_diag", shape=[self._num_units], dtype=dtype)
+          w_i_diag = vs.get_variable(
+              "w_i_diag", shape=[self._num_units], dtype=dtype)
+          w_o_diag = vs.get_variable(
+              "w_o_diag", shape=[self._num_units], dtype=dtype)
+
+      if self._use_peepholes:
+        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+             sigmoid(i + w_i_diag * c_prev) * self._activation(j))
+      else:
+        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
+             self._activation(j))
+
+      if self._cell_clip is not None:
+        # pylint: disable=invalid-unary-operand-type
+        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+        # pylint: enable=invalid-unary-operand-type
+
+      if self._use_peepholes:
+        m = sigmoid(o + w_o_diag * c) * self._activation(c)
+      else:
+        m = sigmoid(o) * self._activation(c)
+
+      if self._num_proj is not None:
+        with vs.variable_scope("projection") as proj_scope:
+          if self._num_proj_shards is not None:
+            proj_scope.set_partitioner(
+                partitioned_variables.fixed_size_partitioner(
+                    self._num_proj_shards))
+          m = _linear(m, self._num_proj, bias=False, scope=scope)
+
+        if self._proj_clip is not None:
+          # pylint: disable=invalid-unary-operand-type
+          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+          # pylint: enable=invalid-unary-operand-type
+
+    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
+                 array_ops.concat([c, m], 1))
+    return m, new_state
 
 
 class OutputProjectionWrapper(RNNCell):
@@ -481,7 +426,7 @@ class OutputProjectionWrapper(RNNCell):
     output, res_state = self._cell(inputs, state)
     # Default scope: "OutputProjectionWrapper"
     with vs.variable_scope(scope or "output_projection_wrapper"):
-      projected = _linear(output, self._output_size, True)
+      projected = _linear(output, self._output_size, True, scope=scope)
     return projected, res_state
 
 
@@ -523,7 +468,7 @@ class InputProjectionWrapper(RNNCell):
     """Run the input projection and then the cell."""
     # Default scope: "InputProjectionWrapper"
     with vs.variable_scope(scope or "input_projection_wrapper"):
-      projected = _linear(inputs, self._num_proj, True)
+      projected = _linear(inputs, self._num_proj, True, scope=scope)
     return self._cell(projected, state)
 
 
@@ -817,7 +762,7 @@ class _SlimRNNCell(RNNCell):
     return output, state
 
 
-def _linear(args, output_size, bias, bias_start=0.0, compiled=False):
+def _linear(args, output_size, bias, bias_start=0.0, scope=None):
   """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable.
 
   Args:
@@ -825,7 +770,7 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
     output_size: int, second dimension of W[i].
     bias: boolean, whether to add a bias term or not.
     bias_start: starting value to initialize the bias; 0 by default.
-    compiled: boolean, _linear plays nicely with XLA if it is enabled.
+    scope: (optional) Variable scope to create parameters in.
 
   Returns:
     A 2D Tensor with shape [batch x output_size] equal to
@@ -870,8 +815,4 @@ def _linear(args, output_size, bias, bias_start=0.0, scope=None):
          "biases", [output_size],
          dtype=dtype,
          initializer=init_ops.constant_initializer(bias_start, dtype=dtype))
-  if compiled:
-    # TODO(b/34505635): Defuns don't play well with bias_add
-    return res + biases
-  else:
-    return nn_ops.bias_add(res, biases)
+  return nn_ops.bias_add(res, biases)
@@ -113,8 +113,7 @@ class BasicSamplingDecoder(decoder.Decoder):
         dtypes.int32)
 
   def initialize(self, name=None):
-    with ops.name_scope("basic_sampling_decoder_initialize"):
-      return self._sampler.initialize() + (self._initial_state,)
+    return self._sampler.initialize() + (self._initial_state,)
 
   def step(self, time, inputs, state):
     """Perform a decoding step.
@@ -127,12 +126,11 @@ class BasicSamplingDecoder(decoder.Decoder):
     Returns:
       `(outputs, next_state, next_inputs, finished)`.
     """
-    with ops.name_scope("basic_sampling_decoder_step"):
-      cell_outputs, next_state = self._cell(inputs, state)
-      (sample_id, finished, next_inputs) = self._sampler.sample(
-          time=time, outputs=cell_outputs, state=next_state)
-      outputs = SamplingDecoderOutput(cell_outputs, sample_id)
-      return (outputs, next_state, next_inputs, finished)
+    cell_outputs, next_state = self._cell(inputs, state)
+    (sample_id, finished, next_inputs) = self._sampler.sample(
+        time=time, outputs=cell_outputs, state=next_state)
+    outputs = SamplingDecoderOutput(cell_outputs, sample_id)
+    return (outputs, next_state, next_inputs, finished)
 
 
 class BasicTrainingSampler(Sampler):
@@ -7,6 +7,7 @@ load("//tensorflow:tensorflow.bzl", "if_not_mobile")
 # configure may change the following lines
 WITH_GCP_SUPPORT = False
 WITH_HDFS_SUPPORT = False
+WITH_XLA_SUPPORT = False
 WITH_JEMALLOC = True
 
 # Appends a suffix to a list of deps.
@@ -241,3 +242,15 @@ def tf_additional_cloud_kernel_deps():
   #if WITH_GCP_SUPPORT:
   #  deps = if_not_mobile(["//tensorflow/core:cloud_ops_op_lib"])
   return deps
+
+def tf_additional_plugin_deps():
+  deps = []
+  if WITH_XLA_SUPPORT:
+    deps.append("//tensorflow/compiler/jit")
+  return deps
+
+def tf_additional_license_deps():
+  licenses = []
+  if WITH_XLA_SUPPORT:
+    licenses.append("@llvm//:LICENSE.TXT")
+  return licenses
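Note: a hedged sketch of how a macro like the restored `tf_additional_plugin_deps` is typically consumed from a BUILD file; only the load path mirrors this change, while the target name and its other dependency are hypothetical:

load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_plugin_deps")

py_library(
    name = "example_lib",  # hypothetical target
    deps = [
        "//tensorflow/python:framework",  # hypothetical dependency
    ] + tf_additional_plugin_deps(),
)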
@@ -2,25 +2,8 @@
 # The functions in this file might be referred by tensorflow.bzl. They have to
 # be separate to avoid cyclic references.
 
-WITH_XLA_SUPPORT = False
-
 def tf_cuda_tests_tags():
   return ["local"]
 
 def tf_sycl_tests_tags():
   return ["local"]
-
-def tf_additional_plugin_deps():
-  deps = []
-  if WITH_XLA_SUPPORT:
-    deps.append("//tensorflow/compiler/jit")
-  return deps
-
-def tf_additional_xla_deps_py():
-  return []
-
-def tf_additional_license_deps():
-  licenses = []
-  if WITH_XLA_SUPPORT:
-    licenses.append("@llvm//:LICENSE.TXT")
-  return licenses
@@ -23,7 +23,7 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
 load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
 load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py")
 load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_plugin_deps")
 load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
 
 py_library(
@@ -12,7 +12,6 @@ load(
     "//tensorflow/core:platform/default/build_config_root.bzl",
     "tf_cuda_tests_tags",
     "tf_sycl_tests_tags",
-    "tf_additional_xla_deps_py",
 )
 load(
     "@local_config_cuda//cuda:build_defs.bzl",
@@ -790,10 +789,7 @@ def py_test(deps=[], **kwargs):
       **kwargs)
 
 def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
-               tags=[], shard_count=1, additional_deps=[], flaky=0,
-               xla_enabled=False):
-  if xla_enabled:
-    additional_deps += tf_additional_xla_deps_py()
+               tags=[], shard_count=1, additional_deps=[], flaky=0):
  native.py_test(
      name=name,
      size=size,
@@ -815,8 +811,7 @@ def tf_py_test(name, srcs, size="medium", data=[], main=None, args=[],
      srcs_version="PY2AND3")
 
 def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
-                 shard_count=1, additional_deps=[], tags=[], flaky=0,
-                 xla_enabled=False):
+                 shard_count=1, additional_deps=[], tags=[], flaky=0):
  test_tags = tags + tf_cuda_tests_tags()
  tf_py_test(name=name,
             size=size,
@@ -827,12 +822,10 @@ def cuda_py_test(name, srcs, size="medium", data=[], main=None, args=[],
             tags=test_tags,
             shard_count=shard_count,
             additional_deps=additional_deps,
-            flaky=flaky,
-            xla_enabled=xla_enabled)
+            flaky=flaky)
 
 def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
-                 shard_count=1, additional_deps=[], tags=[], flaky=0,
-                 xla_enabled=False):
+                 shard_count=1, additional_deps=[], tags=[], flaky=0):
  test_tags = tags + tf_sycl_tests_tags()
  tf_py_test(name=name,
             size=size,
@@ -843,8 +836,7 @@ def sycl_py_test(name, srcs, size="medium", data=[], main=None, args=[],
             tags=test_tags,
             shard_count=shard_count,
             additional_deps=additional_deps,
-            flaky=flaky,
-            xla_enabled=xla_enabled)
+            flaky=flaky)
 
 def py_tests(name,
              srcs,
@@ -853,8 +845,7 @@ def py_tests(name,
             data=[],
             tags=[],
             shard_count=1,
-             prefix="",
-             xla_enabled=False):
+             prefix=""):
  for src in srcs:
    test_name = src.split("/")[-1].split(".")[0]
    if prefix:
@@ -866,15 +857,13 @@ def py_tests(name,
                  tags=tags,
                  shard_count=shard_count,
                  data=data,
-                 additional_deps=additional_deps,
-                 xla_enabled=xla_enabled)
+                 additional_deps=additional_deps)
 
 def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[],
-                  shard_count=1, tags=[], prefix="", xla_enabled=False):
+                  shard_count=1, tags=[], prefix=""):
  test_tags = tags + tf_cuda_tests_tags()
  py_tests(name=name, size=size, srcs=srcs, additional_deps=additional_deps,
-           data=data, tags=test_tags, shard_count=shard_count,prefix=prefix,
-           xla_enabled=xla_enabled)
+           data=data, tags=test_tags, shard_count=shard_count,prefix=prefix)
 
 # Creates a genrule named <name> for running tools/proto_text's generator to
 # make the proto_text functions, for the protos passed in <srcs>.
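Note: with `xla_enabled` dropped from the macro signatures above, call sites look like the BUILD hunks near the top of this commit. A minimal hypothetical invocation of the restored `cuda_py_tests` signature (name, sources, and dependency are made up):

cuda_py_tests(
    name = "example_tests",
    size = "medium",
    srcs = ["example_test.py"],
    additional_deps = [
        "//tensorflow/python:variables",
    ],
)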
@@ -4,7 +4,7 @@
 package(default_visibility = ["//visibility:private"])
 
 load("//tensorflow:tensorflow.bzl", "transitive_hdrs")
-load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_license_deps")
 
 # This returns a list of headers of all public header libraries (e.g.,
 # framework, lib), and all of the transitive dependencies of those