Merge pull request #27756 from houtoms:lstmp_upstream
PiperOrigin-RevId: 252553312
This commit is contained in:
commit
9380a41290
@ -27,6 +27,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
|
||||
from tensorflow.core.protobuf import saver_pb2
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -71,7 +72,8 @@ def RunLSTM(sess,
|
||||
is_training=True,
|
||||
dropout=0.,
|
||||
num_dirs=True,
|
||||
dtype=dtypes.float32):
|
||||
dtype=dtypes.float32,
|
||||
num_proj=None):
|
||||
# TODO(jamesqin): add multi-layer tests.
|
||||
# TODO(jamesqin): add multi-dir tests
|
||||
assert num_layers == 1
|
||||
@ -91,10 +93,12 @@ def RunLSTM(sess,
|
||||
inputs_dynamic = array_ops.placeholder(
|
||||
dtype, shape=[None, None, None], name="inputs")
|
||||
inputs = inputs_dynamic if dynamic_shape_input else inputs_static
|
||||
unified_num_units = num_proj if num_proj else num_units
|
||||
unified_num_proj = num_proj if num_proj else None
|
||||
initial_h_op = variable_scope.get_variable(
|
||||
"initial_h_op",
|
||||
initializer=np.random.rand(batch_size,
|
||||
num_units).astype(dtype.as_numpy_dtype),
|
||||
initializer=np.random.rand(batch_size, unified_num_units).astype(
|
||||
dtype.as_numpy_dtype),
|
||||
dtype=dtype)
|
||||
initial_c_op = variable_scope.get_variable(
|
||||
"initial_c_op",
|
||||
@ -115,13 +119,19 @@ def RunLSTM(sess,
|
||||
with variable_scope.variable_scope("test", initializer=initializer):
|
||||
w = variable_scope.get_variable(
|
||||
"rnn/lstm_cell/kernel",
|
||||
shape=[input_size + num_units, num_units * 4],
|
||||
shape=[input_size + unified_num_units, num_units * 4],
|
||||
dtype=dtype)
|
||||
b = variable_scope.get_variable(
|
||||
"rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype)
|
||||
if num_proj:
|
||||
pw = variable_scope.get_variable(
|
||||
"rnn/lstm_cell/projection/kernel",
|
||||
shape=[num_units, num_proj],
|
||||
dtype=dtype)
|
||||
|
||||
# canonical lstm. must set forget_bias to 0. to align with cudnn lstm.
|
||||
cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True)
|
||||
cell = rnn_cell_impl.LSTMCell(
|
||||
num_units, forget_bias=0., reuse=True, num_proj=unified_num_proj)
|
||||
outputs_op, state_tuple_op = rnn.dynamic_rnn(
|
||||
cell,
|
||||
inputs_static,
|
||||
@ -134,8 +144,13 @@ def RunLSTM(sess,
|
||||
|
||||
# Convert to cudnn opaque param.
|
||||
format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM(
|
||||
num_layers, num_units, input_size)
|
||||
opaque_params = format_converter.tf_canonical_to_opaque([w, b])
|
||||
num_layers, num_units, input_size, num_proj=unified_num_proj)
|
||||
if num_proj:
|
||||
opaque_params = format_converter.tf_canonical_to_opaque([w, b], [
|
||||
pw,
|
||||
])
|
||||
else:
|
||||
opaque_params = format_converter.tf_canonical_to_opaque([w, b])
|
||||
|
||||
cu_initial_h_op = array_ops.expand_dims(
|
||||
initial_h_op, axis=(0 if time_major else 1))
|
||||
@ -150,16 +165,22 @@ def RunLSTM(sess,
|
||||
time_major=time_major,
|
||||
dropout=dropout,
|
||||
is_training=is_training,
|
||||
rnn_mode=cudnn_rnn_ops.CUDNN_LSTM)
|
||||
rnn_mode=cudnn_rnn_ops.CUDNN_LSTM,
|
||||
num_proj=unified_num_proj)
|
||||
# Remove the trivial 1st dimension.
|
||||
cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple(
|
||||
c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1),
|
||||
h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1))
|
||||
|
||||
if is_training:
|
||||
(inp_grad_op, hgrad_op,
|
||||
cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients(
|
||||
outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b])
|
||||
if num_proj:
|
||||
(inp_grad_op, hgrad_op, cgrad_op,
|
||||
wgrad_op, bgrad_op, pwgrad_op) = gradients_impl.gradients(
|
||||
outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b, pw])
|
||||
else:
|
||||
(inp_grad_op, hgrad_op,
|
||||
cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients(
|
||||
outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b])
|
||||
|
||||
(cu_inp_grad_op, cu_hgrad_op,
|
||||
cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients(
|
||||
@ -170,10 +191,16 @@ def RunLSTM(sess,
|
||||
# Remove the trivial 1st dimension
|
||||
cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1)
|
||||
|
||||
cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
|
||||
opaque_grad_op)
|
||||
if num_proj:
|
||||
cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op = \
|
||||
format_converter.opaque_to_tf_canonical(opaque_grad_op)
|
||||
else:
|
||||
cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
|
||||
opaque_grad_op)
|
||||
cu_wgrad_op = cu_wgrad_op[0]
|
||||
cu_bgrad_op = cu_bgrad_op[0]
|
||||
if num_proj:
|
||||
cu_pwgrad_op = cu_pwgrad_op[0]
|
||||
# cudnn lstm has 2 biases each gate. When converting to tf canonical format,
|
||||
# the two biases are summed into one. Thus here bias gradient should be
|
||||
# halved when comparing with tf lstm.
|
||||
@ -183,17 +210,32 @@ def RunLSTM(sess,
|
||||
sess.run(init_op)
|
||||
|
||||
if is_training:
|
||||
outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([
|
||||
outputs_op, state_tuple_op, inp_grad_op,
|
||||
(hgrad_op, cgrad_op), wgrad_op, bgrad_op
|
||||
])
|
||||
(cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
|
||||
cu_bgrad) = sess.run(
|
||||
[
|
||||
cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
|
||||
(cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
|
||||
],
|
||||
feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
|
||||
if num_proj:
|
||||
(outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad,
|
||||
pwgrad) = sess.run([
|
||||
outputs_op, state_tuple_op, inp_grad_op, (hgrad_op, cgrad_op),
|
||||
wgrad_op, bgrad_op, pwgrad_op
|
||||
])
|
||||
(cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
|
||||
cu_bgrad, cu_pwgrad) = sess.run(
|
||||
[
|
||||
cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
|
||||
(cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op,
|
||||
cu_pwgrad_op
|
||||
],
|
||||
feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
|
||||
else:
|
||||
outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([
|
||||
outputs_op, state_tuple_op, inp_grad_op, (hgrad_op, cgrad_op),
|
||||
wgrad_op, bgrad_op
|
||||
])
|
||||
(cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
|
||||
cu_bgrad) = sess.run(
|
||||
[
|
||||
cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
|
||||
(cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
|
||||
],
|
||||
feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
|
||||
|
||||
logging.vlog(1, "outputs: %s" % outputs)
|
||||
logging.vlog(1, "cu_outputs: %s" % cu_outputs)
|
||||
@ -205,11 +247,20 @@ def RunLSTM(sess,
|
||||
logging.vlog(1, "cu_state_grad: %s" % str(cu_state_grad))
|
||||
logging.vlog(1, "wgrad: %s" % str(wgrad))
|
||||
logging.vlog(1, "bgrad: %s" % str(bgrad))
|
||||
if num_proj:
|
||||
logging.vlog(1, "pwgrad: %s" % str(bgrad))
|
||||
logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad))
|
||||
logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad))
|
||||
return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
|
||||
cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad,
|
||||
cu_bgrad)
|
||||
if num_proj:
|
||||
logging.vlog(1, "cu_pwgrad: %s" % str(cu_bgrad))
|
||||
if num_proj:
|
||||
return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
|
||||
cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad,
|
||||
cu_wgrad, cu_bgrad, cu_pwgrad)
|
||||
else:
|
||||
return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
|
||||
cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad,
|
||||
cu_bgrad)
|
||||
else:
|
||||
outputs, state_tuple = sess.run([outputs_op, state_tuple_op])
|
||||
cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op],
|
||||
@ -256,7 +307,6 @@ NAMED_RNN_TESTCASES = ({
|
||||
"num_layers": 1,
|
||||
})
|
||||
|
||||
|
||||
def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs):
|
||||
"""Expands testcase with new config dimensions.
|
||||
|
||||
@ -349,19 +399,35 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
time_major,
|
||||
dynamic_shape_input=False,
|
||||
rtol=3e-6,
|
||||
atol=3e-6):
|
||||
atol=3e-6,
|
||||
num_proj=None):
|
||||
with self.session(use_gpu=True) as sess:
|
||||
(outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad,
|
||||
state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input)
|
||||
if num_proj is not None and num_proj != 0:
|
||||
(outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
|
||||
cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad, cu_wgrad,
|
||||
cu_bgrad, cu_pwgrad) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
dynamic_shape_input=dynamic_shape_input,
|
||||
num_proj=num_proj)
|
||||
else:
|
||||
(outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
|
||||
cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad,
|
||||
cu_bgrad) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
dynamic_shape_input=dynamic_shape_input,
|
||||
num_proj=num_proj)
|
||||
|
||||
self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
|
||||
for s, cu_s in zip(state_tuple, cu_state_tuple):
|
||||
@ -371,6 +437,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol)
|
||||
self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol)
|
||||
self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol)
|
||||
if num_proj is not None and num_proj != 0:
|
||||
self.assertAllClose(pwgrad, cu_pwgrad, rtol=rtol, atol=atol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
ExpandNamedTestCases(
|
||||
@ -378,20 +446,27 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
"variable_seq_lengths": [True, False],
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
"use_proj": [True, False],
|
||||
}))
|
||||
@test_util.run_gpu_only
|
||||
def test_training(self, num_units, input_size, batch_size, time, num_layers,
|
||||
variable_seq_lengths, time_major, dynamic_shape_input):
|
||||
self._test_training_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
dtypes.float32,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input)
|
||||
variable_seq_lengths, time_major, dynamic_shape_input,
|
||||
use_proj):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 27):
|
||||
num_proj = num_units // 2
|
||||
if use_proj and num_proj == 0:
|
||||
self.skipTest("num_proj cannot be 0")
|
||||
self._test_training_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
dtypes.float32,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input,
|
||||
num_proj=num_proj if use_proj else None)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
ExpandNamedTestCases(
|
||||
@ -399,52 +474,29 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
"variable_seq_lengths": [True, False],
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
"use_proj": [True, False],
|
||||
}))
|
||||
@test_util.run_gpu_only
|
||||
def test_training_fp16(self, num_units, input_size, batch_size, time,
|
||||
num_layers, variable_seq_lengths, time_major,
|
||||
dynamic_shape_input):
|
||||
self._test_training_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
dtypes.float16,
|
||||
rtol=5e-3,
|
||||
atol=5e-4,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
ExpandNamedTestCases(
|
||||
NAMED_RNN_TESTCASES, **{
|
||||
"variable_seq_lengths": [True, False],
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
}))
|
||||
@test_util.run_gpu_only
|
||||
def test_inference(self, num_units, input_size, batch_size, time, num_layers,
|
||||
variable_seq_lengths, time_major, dynamic_shape_input):
|
||||
with self.session(use_gpu=True) as sess:
|
||||
(outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
|
||||
sess,
|
||||
dynamic_shape_input, use_proj):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 27):
|
||||
num_proj = num_units // 2
|
||||
if use_proj and num_proj == 0:
|
||||
self.skipTest("num_proj cannot be 0")
|
||||
self._test_training_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
is_training=False,
|
||||
dtypes.float16,
|
||||
rtol=5e-3,
|
||||
atol=5e-4,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input)
|
||||
|
||||
self.assertAllClose(outputs, cu_outputs)
|
||||
# h
|
||||
self.assertAllClose(state_tuple.h, cu_state_tuple.h)
|
||||
# c
|
||||
self.assertAllClose(state_tuple.c, cu_state_tuple.c)
|
||||
dynamic_shape_input=dynamic_shape_input,
|
||||
num_proj=num_proj if use_proj else None)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
ExpandNamedTestCases(
|
||||
@ -452,33 +504,75 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
"variable_seq_lengths": [True, False],
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
"use_proj": [True, False],
|
||||
}))
|
||||
@test_util.run_gpu_only
|
||||
def test_inference(self, num_units, input_size, batch_size, time, num_layers,
|
||||
variable_seq_lengths, time_major, dynamic_shape_input,
|
||||
use_proj):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 27):
|
||||
num_proj = num_units // 2
|
||||
if use_proj and num_proj == 0:
|
||||
self.skipTest("num_proj cannot be 0")
|
||||
with self.session(use_gpu=True) as sess:
|
||||
(outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
is_training=False,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input,
|
||||
num_proj=num_proj if use_proj else None)
|
||||
|
||||
self.assertAllClose(outputs, cu_outputs)
|
||||
# h
|
||||
self.assertAllClose(state_tuple.h, cu_state_tuple.h)
|
||||
# c
|
||||
self.assertAllClose(state_tuple.c, cu_state_tuple.c)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
ExpandNamedTestCases(
|
||||
NAMED_RNN_TESTCASES, **{
|
||||
"variable_seq_lengths": [True, False],
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
"use_proj": [True, False],
|
||||
}))
|
||||
@test_util.run_gpu_only
|
||||
def test_inference_fp16(self, num_units, input_size, batch_size, time,
|
||||
num_layers, variable_seq_lengths, time_major,
|
||||
dynamic_shape_input):
|
||||
with self.session(use_gpu=True) as sess:
|
||||
(outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
is_training=False,
|
||||
dtype=dtypes.float16,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input)
|
||||
dynamic_shape_input, use_proj):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 27):
|
||||
num_proj = num_units // 2
|
||||
if use_proj and num_proj == 0:
|
||||
self.skipTest("num_proj cannot be 0")
|
||||
with self.session(use_gpu=True) as sess:
|
||||
(outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
is_training=False,
|
||||
dtype=dtypes.float16,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input,
|
||||
num_proj=num_proj if use_proj else None)
|
||||
|
||||
rtol, atol = 5e-3, 5e-4
|
||||
self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
|
||||
# h
|
||||
self.assertAllClose(
|
||||
state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol)
|
||||
# c
|
||||
self.assertAllClose(
|
||||
state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol)
|
||||
rtol, atol = 5e-3, 5e-4
|
||||
self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
|
||||
# h
|
||||
self.assertAllClose(
|
||||
state_tuple.h, cu_state_tuple.h, rtol=rtol, atol=atol)
|
||||
# c
|
||||
self.assertAllClose(
|
||||
state_tuple.c, cu_state_tuple.c, rtol=rtol, atol=atol)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
ExpandNamedTestCases(
|
||||
@ -486,49 +580,56 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
"variable_seq_lengths": [True, False],
|
||||
"time_major": [True, False],
|
||||
"dynamic_shape_input": [True, False],
|
||||
"use_proj": [True, False],
|
||||
}))
|
||||
@test_util.run_gpu_only
|
||||
def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
|
||||
num_layers, variable_seq_lengths, time_major,
|
||||
dynamic_shape_input):
|
||||
dynamic_shape_input, use_proj):
|
||||
"""Validates that dropout does not affect Cudnn Rnn inference."""
|
||||
# Hand-picked dropouts are used below (0. and 1.)
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.session(use_gpu=True, graph=g) as sess:
|
||||
# 1st time w/o dropout.
|
||||
(_, cu_outputs, _, cu_state_tuple) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
is_training=False,
|
||||
dropout=0.,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input)
|
||||
with compat.forward_compatibility_horizon(2019, 6, 27):
|
||||
num_proj = num_units // 2
|
||||
if use_proj and num_proj == 0:
|
||||
self.skipTest("num_proj cannot be 0")
|
||||
# Hand-picked dropouts are used below (0. and 1.)
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.session(use_gpu=True, graph=g) as sess:
|
||||
# 1st time w/o dropout.
|
||||
(_, cu_outputs, _, cu_state_tuple) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
is_training=False,
|
||||
dropout=0.,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input,
|
||||
num_proj=num_proj if use_proj else None)
|
||||
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.session(use_gpu=True, graph=g) as sess:
|
||||
(_, cu_outputs2, _, cu_state_tuple2) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
is_training=False,
|
||||
dropout=1.,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input)
|
||||
with ops.Graph().as_default() as g:
|
||||
with self.session(use_gpu=True, graph=g) as sess:
|
||||
(_, cu_outputs2, _, cu_state_tuple2) = RunLSTM(
|
||||
sess,
|
||||
num_units,
|
||||
input_size,
|
||||
batch_size,
|
||||
time,
|
||||
num_layers,
|
||||
is_training=False,
|
||||
dropout=1.,
|
||||
variable_seq_lengths=variable_seq_lengths,
|
||||
time_major=time_major,
|
||||
dynamic_shape_input=dynamic_shape_input,
|
||||
num_proj=num_proj if use_proj else None)
|
||||
|
||||
self.assertAllClose(cu_outputs, cu_outputs2)
|
||||
# h
|
||||
self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h)
|
||||
# c
|
||||
self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c)
|
||||
self.assertAllClose(cu_outputs, cu_outputs2)
|
||||
# h
|
||||
self.assertAllClose(cu_state_tuple.h, cu_state_tuple2.h)
|
||||
# c
|
||||
self.assertAllClose(cu_state_tuple.c, cu_state_tuple2.c)
|
||||
|
||||
|
||||
def RunGRU(sess,
|
||||
@ -890,40 +991,68 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
"""Class for testing various format converters."""
|
||||
|
||||
def _test_lstm_helper(self, num_units, input_size, num_layers, direction):
|
||||
def _test_lstm_helper(self,
|
||||
num_units,
|
||||
input_size,
|
||||
num_layers,
|
||||
direction,
|
||||
num_proj=None):
|
||||
with self.session(use_gpu=True) as sess:
|
||||
random_seed.set_random_seed(0)
|
||||
np.random.seed(0)
|
||||
|
||||
num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2
|
||||
format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM(
|
||||
num_layers, num_units, input_size, direction=direction)
|
||||
num_layers,
|
||||
num_units,
|
||||
input_size,
|
||||
direction=direction,
|
||||
num_proj=num_proj if num_proj else None)
|
||||
|
||||
ws, bs = [], []
|
||||
ws, bs, pws = [], [], []
|
||||
for _ in range(num_layers * num_dirs):
|
||||
w = constant_op.constant(
|
||||
np.random.rand(input_size + num_units, 4 * num_units),
|
||||
np.random.rand(input_size + (num_proj if num_proj else num_units),
|
||||
4 * num_units),
|
||||
dtype=dtypes.float32)
|
||||
b = constant_op.constant(
|
||||
np.random.rand(4 * num_units), dtype=dtypes.float32)
|
||||
ws.append(w)
|
||||
bs.append(b)
|
||||
if num_proj:
|
||||
pw = constant_op.constant(
|
||||
np.random.rand(num_units, num_proj), dtype=dtypes.float32)
|
||||
pws.append(pw)
|
||||
|
||||
if num_proj:
|
||||
opaque_params = format_converter.tf_canonical_to_opaque(ws + bs, pws)
|
||||
else:
|
||||
opaque_params = format_converter.tf_canonical_to_opaque(ws + bs)
|
||||
|
||||
opaque_params = format_converter.tf_canonical_to_opaque(ws + bs)
|
||||
opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size(
|
||||
cudnn_rnn_ops.CUDNN_LSTM,
|
||||
num_layers,
|
||||
num_units,
|
||||
input_size,
|
||||
direction=direction)
|
||||
direction=direction,
|
||||
num_proj=num_proj if num_proj else None)
|
||||
|
||||
ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params)
|
||||
if num_proj:
|
||||
ws_r, bs_r, pws_r = format_converter.opaque_to_tf_canonical(
|
||||
opaque_params)
|
||||
ws, ws_r, pws, bs, bs_r, pws_r = sess.run(
|
||||
[ws, ws_r, pws, bs, bs_r, pws_r])
|
||||
else:
|
||||
ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params)
|
||||
ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r])
|
||||
|
||||
# Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical()
|
||||
# returns the original input.
|
||||
ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r])
|
||||
for w, w_r in zip(ws, ws_r):
|
||||
self.assertAllClose(w, w_r)
|
||||
if num_proj:
|
||||
for pw, pw_r in zip(pws, pws_r):
|
||||
self.assertAllClose(pw, pw_r)
|
||||
for b, b_r in zip(bs, bs_r):
|
||||
self.assertAllClose(b, b_r)
|
||||
|
||||
@ -942,6 +1071,22 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
|
||||
self._test_lstm_helper(num_units, input_size, num_layers,
|
||||
cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
(c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"])
|
||||
for c in NAMED_RNN_TESTCASES)
|
||||
@test_util.run_gpu_only
|
||||
def test_lstmp(self, num_units, input_size, num_layers):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 27):
|
||||
num_proj = num_units // 2
|
||||
if num_proj == 0:
|
||||
self.skipTest("num_proj cannot be 0")
|
||||
self._test_lstm_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
num_layers,
|
||||
cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
|
||||
num_proj=num_proj)
|
||||
|
||||
@parameterized.named_parameters((c["testcase_name"], c["num_units"],
|
||||
c["input_size"], c["num_layers"])
|
||||
for c in NAMED_RNN_TESTCASES)
|
||||
@ -950,6 +1095,22 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
|
||||
self._test_lstm_helper(num_units, input_size, num_layers,
|
||||
cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
(c["testcase_name"], c["num_units"], c["input_size"], c["num_layers"])
|
||||
for c in NAMED_RNN_TESTCASES)
|
||||
@test_util.run_gpu_only
|
||||
def test_lstmp_bidi(self, num_units, input_size, num_layers):
|
||||
with compat.forward_compatibility_horizon(2019, 6, 27):
|
||||
num_proj = num_units // 2
|
||||
if num_proj == 0:
|
||||
self.skipTest("num_proj cannot be 0")
|
||||
self._test_lstm_helper(
|
||||
num_units,
|
||||
input_size,
|
||||
num_layers,
|
||||
cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION,
|
||||
num_proj=num_proj)
|
||||
|
||||
def _test_gru_helper(self, num_units, input_size, num_layers, direction):
|
||||
with self.session(use_gpu=True) as sess:
|
||||
random_seed.set_random_seed(0)
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
import os
|
||||
from tensorflow.contrib.checkpoint.python import split_dependency
|
||||
from tensorflow.contrib.rnn.python.ops import lstm_ops
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
@ -186,6 +187,7 @@ class CudnnParamsFormatConverter(object):
|
||||
num_layers,
|
||||
num_units,
|
||||
input_size,
|
||||
num_proj=None,
|
||||
input_mode=CUDNN_INPUT_LINEAR_MODE,
|
||||
direction=CUDNN_RNN_UNIDIRECTION):
|
||||
"""Constructor.
|
||||
@ -195,6 +197,8 @@ class CudnnParamsFormatConverter(object):
|
||||
num_units: the number of units within the RNN model.
|
||||
input_size: the size of the input, it could be different from the
|
||||
num_units.
|
||||
num_proj: The output dimensionality for the projection matrices.
|
||||
If None or 0, no projection is performed.
|
||||
input_mode: indicate whether there is a linear projection between the
|
||||
input and the actual computation before the first layer. It could be one
|
||||
of 'linear_input', 'skip_input' or 'auto_select'. * 'linear_input'
|
||||
@ -209,14 +213,16 @@ class CudnnParamsFormatConverter(object):
|
||||
self._input_size = input_size
|
||||
self._num_units = num_units
|
||||
self._input_mode = input_mode
|
||||
self._num_proj = num_proj
|
||||
self._direction = direction
|
||||
self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2
|
||||
self._num_params = (
|
||||
self._num_params_per_layer * self._num_layers * self._num_dirs)
|
||||
|
||||
def tf_canonical_to_opaque(self, tf_canonicals):
|
||||
def tf_canonical_to_opaque(self, tf_canonicals, weights_proj=None):
|
||||
r"""Converts tf canonical weights to cudnn opaque param."""
|
||||
cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals)
|
||||
cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(
|
||||
tf_canonicals, weights_proj)
|
||||
cu_weights = [array_ops.reshape(w, [-1]) for w in cu_weights]
|
||||
opaque_params = self._cu_canonical_to_opaque(cu_weights, cu_biases)
|
||||
return opaque_params
|
||||
@ -224,8 +230,14 @@ class CudnnParamsFormatConverter(object):
|
||||
def opaque_to_tf_canonical(self, opaque_param):
|
||||
r"""Converts cudnn opaque param to tf canonical weights."""
|
||||
cu_weights, cu_biases = self._opaque_to_cu_canonical(opaque_param)
|
||||
weights, biases = self._cu_canonical_to_tf_canonical(cu_weights, cu_biases)
|
||||
return weights, biases
|
||||
if self._num_proj:
|
||||
weights, biases, weights_proj = self._cu_canonical_to_tf_canonical(
|
||||
cu_weights, cu_biases)
|
||||
return weights, biases, weights_proj
|
||||
else:
|
||||
weights, biases = self._cu_canonical_to_tf_canonical(
|
||||
cu_weights, cu_biases)
|
||||
return weights, biases
|
||||
|
||||
def _opaque_to_cu_canonical(self, opaque_param):
|
||||
"""Converts opaque params to Cudnn canonical format.
|
||||
@ -238,15 +250,31 @@ class CudnnParamsFormatConverter(object):
|
||||
2 list for weights and biases respectively.
|
||||
"""
|
||||
with ops.device("/gpu:0"):
|
||||
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
params=opaque_param,
|
||||
num_params=self._num_params,
|
||||
rnn_mode=self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction)
|
||||
if compat.forward_compatible(2019, 6, 26) and self._num_proj:
|
||||
num_params_weights = (
|
||||
self._num_params + 1 * self._num_layers * self._num_dirs)
|
||||
num_params_biases = self._num_params
|
||||
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2(
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
params=opaque_param,
|
||||
rnn_mode=self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction,
|
||||
num_params_weights=num_params_weights,
|
||||
num_params_biases=num_params_biases,
|
||||
num_proj=self._num_proj)
|
||||
else:
|
||||
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
params=opaque_param,
|
||||
num_params=self._num_params,
|
||||
rnn_mode=self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction)
|
||||
return (weights, biases)
|
||||
|
||||
def _cu_canonical_to_opaque(self, cu_weights, cu_biases):
|
||||
@ -260,15 +288,27 @@ class CudnnParamsFormatConverter(object):
|
||||
a single opaque tensor.
|
||||
"""
|
||||
with ops.device("/gpu:0"):
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
weights=cu_weights,
|
||||
biases=cu_biases,
|
||||
rnn_mode=self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction)
|
||||
if compat.forward_compatible(2019, 6, 26) and self._num_proj:
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2(
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
weights=cu_weights,
|
||||
biases=cu_biases,
|
||||
rnn_mode=self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
num_proj=self._num_proj,
|
||||
direction=self._direction)
|
||||
else:
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
weights=cu_weights,
|
||||
biases=cu_biases,
|
||||
rnn_mode=self._rnn_mode,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction)
|
||||
|
||||
def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases):
|
||||
r"""Transform from Cudnn canonical to tf canonical.
|
||||
@ -294,9 +334,11 @@ class CudnnParamsFormatConverter(object):
|
||||
1 tuple, tf canonical weights and biases.
|
||||
"""
|
||||
tf_weights, tf_biases = [], []
|
||||
tf_weights_proj = []
|
||||
|
||||
layer_weights_num = self._num_params_per_layer * self._num_dirs
|
||||
layer_biases_num = layer_weights_num
|
||||
layer_weights_num += (1 * self._num_dirs) if self._num_proj else 0
|
||||
|
||||
for i in range(self._num_layers):
|
||||
layer_weights = cu_weights[i * layer_weights_num:(i + 1) *
|
||||
@ -305,7 +347,8 @@ class CudnnParamsFormatConverter(object):
|
||||
if self._direction == CUDNN_RNN_UNIDIRECTION:
|
||||
self._cu_canonical_to_tf_canonical_single_layer(layer_weights,
|
||||
layer_biases,
|
||||
tf_weights, tf_biases)
|
||||
tf_weights, tf_biases,
|
||||
tf_weights_proj)
|
||||
else:
|
||||
fw_weights = layer_weights[:len(layer_weights) // 2]
|
||||
bw_weights = layer_weights[len(layer_weights) // 2:]
|
||||
@ -317,6 +360,7 @@ class CudnnParamsFormatConverter(object):
|
||||
fw_biases,
|
||||
tf_weights,
|
||||
tf_biases,
|
||||
tf_weights_proj,
|
||||
)
|
||||
|
||||
self._cu_canonical_to_tf_canonical_single_layer(
|
||||
@ -324,11 +368,19 @@ class CudnnParamsFormatConverter(object):
|
||||
bw_biases,
|
||||
tf_weights,
|
||||
tf_biases,
|
||||
tf_weights_proj,
|
||||
)
|
||||
return (tf_weights, tf_biases)
|
||||
if self._num_proj:
|
||||
return (tf_weights, tf_biases, tf_weights_proj)
|
||||
else:
|
||||
return (tf_weights, tf_biases)
|
||||
|
||||
def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases,
|
||||
tf_weights, tf_biases):
|
||||
def _cu_canonical_to_tf_canonical_single_layer(self,
|
||||
cu_weights,
|
||||
cu_biases,
|
||||
tf_weights,
|
||||
tf_biases,
|
||||
tf_weigths_proj=None):
|
||||
r"""Transform single layer Cudnn canonicals to tf canonicals.
|
||||
|
||||
The elements of cu_weights, cu_biases are laid out in the following format:
|
||||
@ -343,7 +395,7 @@ class CudnnParamsFormatConverter(object):
|
||||
"""
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def _tf_canonical_to_cu_canonical(self, tf_canonicals):
|
||||
def _tf_canonical_to_cu_canonical(self, tf_canonicals, weights_proj):
|
||||
r"""Transform from tf canonical to Cudnn canonical.
|
||||
|
||||
This is the reverse routine of _TransformCanonical().
|
||||
@ -362,6 +414,7 @@ class CudnnParamsFormatConverter(object):
|
||||
---------------
|
||||
|fwd |bak |
|
||||
---------------
|
||||
weights_proj: (optional) weights matrices for projection
|
||||
Returns:
|
||||
2 lists: the recovered cudnn canonical weights and biases.
|
||||
"""
|
||||
@ -376,6 +429,9 @@ class CudnnParamsFormatConverter(object):
|
||||
layer_biases = biases[i * layer_biases_num:(i + 1) * layer_biases_num]
|
||||
if self._direction == CUDNN_RNN_UNIDIRECTION:
|
||||
cu_weights.extend(self._tf_to_cudnn_weights(i, *layer_weights))
|
||||
if weights_proj is not None:
|
||||
pw = array_ops.transpose(weights_proj[i])
|
||||
cu_weights.append(pw)
|
||||
cu_biases.extend(self._tf_to_cudnn_biases(*layer_biases))
|
||||
else:
|
||||
fw_weights, bw_weights = layer_weights[:len(layer_weights) //
|
||||
@ -385,9 +441,15 @@ class CudnnParamsFormatConverter(object):
|
||||
2], layer_biases[len(layer_biases
|
||||
) // 2:]
|
||||
cu_weights.extend(self._tf_to_cudnn_weights(i, *fw_weights))
|
||||
if weights_proj is not None:
|
||||
pw0 = array_ops.transpose(weights_proj[2 * i + 0])
|
||||
cu_weights.append(pw0)
|
||||
cu_biases.extend(self._tf_to_cudnn_biases(*fw_biases))
|
||||
|
||||
cu_weights.extend(self._tf_to_cudnn_weights(i, *bw_weights))
|
||||
if weights_proj is not None:
|
||||
pw1 = array_ops.transpose(weights_proj[2 * i + 1])
|
||||
cu_weights.append(pw1)
|
||||
cu_biases.extend(self._tf_to_cudnn_biases(*bw_biases))
|
||||
return cu_weights, cu_biases
|
||||
|
||||
@ -423,7 +485,10 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
|
||||
|
||||
def _cudnn_to_tf_weights(self, *cu_weights):
|
||||
r"""Stitching cudnn canonical weights to generate tf canonical weights."""
|
||||
w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights
|
||||
if self._num_proj:
|
||||
w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o, pw = cu_weights
|
||||
else:
|
||||
w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
W_i = array_ops.concat([w_i, r_i], axis=1)
|
||||
@ -433,7 +498,11 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
|
||||
# pylint: enable=invalid-name
|
||||
# Cudnn LSTM weights are in ifco order, other tf LSTMs are in icfo order.
|
||||
reordered = self._cudnn_to_tf_gate_params(*[W_i, W_f, W_c, W_o])
|
||||
return (array_ops.transpose(array_ops.concat(reordered, axis=0)),)
|
||||
if self._num_proj:
|
||||
return (array_ops.transpose(array_ops.concat(reordered, axis=0)),
|
||||
array_ops.transpose(pw))
|
||||
else:
|
||||
return (array_ops.transpose(array_ops.concat(reordered, axis=0)),)
|
||||
|
||||
def _tf_to_cudnn_weights(self, layer, *tf_weights):
|
||||
r"""Reverse the operations in StitchWeights()."""
|
||||
@ -442,7 +511,7 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
|
||||
if layer == 0:
|
||||
input_weight_width = input_size
|
||||
else:
|
||||
input_weight_width = num_units
|
||||
input_weight_width = self._num_proj if self._num_proj else num_units
|
||||
if self._direction == CUDNN_RNN_BIDIRECTION:
|
||||
input_weight_width *= 2
|
||||
|
||||
@ -452,10 +521,15 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
|
||||
W_i, W_f, W_c, W_o = self._tf_to_cudnn_gate_params(
|
||||
*array_ops.split(w, 4, axis=0))
|
||||
|
||||
w_i, r_i = array_ops.split(W_i, [input_weight_width, num_units], axis=1)
|
||||
w_c, r_c = array_ops.split(W_c, [input_weight_width, num_units], axis=1)
|
||||
w_f, r_f = array_ops.split(W_f, [input_weight_width, num_units], axis=1)
|
||||
w_o, r_o = array_ops.split(W_o, [input_weight_width, num_units], axis=1)
|
||||
hidden_state_width = self._num_proj if self._num_proj else num_units
|
||||
w_i, r_i = array_ops.split(
|
||||
W_i, [input_weight_width, hidden_state_width], axis=1)
|
||||
w_c, r_c = array_ops.split(
|
||||
W_c, [input_weight_width, hidden_state_width], axis=1)
|
||||
w_f, r_f = array_ops.split(
|
||||
W_f, [input_weight_width, hidden_state_width], axis=1)
|
||||
w_o, r_o = array_ops.split(
|
||||
W_o, [input_weight_width, hidden_state_width], axis=1)
|
||||
return w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
@ -490,11 +564,20 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
|
||||
# Return ifco order for Cudnn LSTM.
|
||||
return b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro
|
||||
|
||||
def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases,
|
||||
tf_weights, tf_biases):
|
||||
(w,) = self._cudnn_to_tf_weights(*cu_weights)
|
||||
def _cu_canonical_to_tf_canonical_single_layer(self,
|
||||
cu_weights,
|
||||
cu_biases,
|
||||
tf_weights,
|
||||
tf_biases,
|
||||
tf_weights_proj=None):
|
||||
if self._num_proj:
|
||||
(w, pw) = self._cudnn_to_tf_weights(*cu_weights)
|
||||
tf_weights.append(w)
|
||||
tf_weights_proj.append(pw)
|
||||
else:
|
||||
(w,) = self._cudnn_to_tf_weights(*cu_weights)
|
||||
tf_weights.append(w)
|
||||
(b,) = self._cudnn_to_tf_biases(*cu_biases)
|
||||
tf_weights.append(w)
|
||||
tf_biases.append(b)
|
||||
|
||||
|
||||
@ -561,8 +644,12 @@ class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter):
|
||||
b_ri, b_rr = array_ops.split(br, 2, axis=0)
|
||||
return b_wi, b_wr, b_wh, b_ri, b_rr, b_rh
|
||||
|
||||
def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases,
|
||||
tf_weights, tf_biases):
|
||||
def _cu_canonical_to_tf_canonical_single_layer(self,
|
||||
cu_weights,
|
||||
cu_biases,
|
||||
tf_weights,
|
||||
tf_biases,
|
||||
tf_weights_proj=None):
|
||||
# pylint: disable=invalid-name
|
||||
W_ir, w_h, r_h = self._cudnn_to_tf_weights(*cu_weights)
|
||||
b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*cu_biases)
|
||||
@ -735,8 +822,11 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
|
||||
def format_converter(self):
|
||||
if self._format_converter is None:
|
||||
self._format_converter = self._format_converter_cls(
|
||||
self._num_layers, self._num_units, self._input_size, self._input_mode,
|
||||
self._direction)
|
||||
self._num_layers,
|
||||
self._num_units,
|
||||
self._input_size,
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction)
|
||||
return self._format_converter
|
||||
|
||||
def restore(self, restored_tensors, restored_shapes):
|
||||
@ -970,6 +1060,7 @@ def _cudnn_rnn(inputs,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0.,
|
||||
seed=0,
|
||||
num_proj=None,
|
||||
name=None):
|
||||
"""Cudnn RNN.
|
||||
|
||||
@ -1006,6 +1097,8 @@ def _cudnn_rnn(inputs,
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See
|
||||
`tf.compat.v1.set_random_seed` for behavior.
|
||||
num_proj: The output dimensionality for the projection matrices.
|
||||
If None or 0, no projection is performed.
|
||||
name: name of the operation.
|
||||
|
||||
Returns:
|
||||
@ -1035,13 +1128,16 @@ def _cudnn_rnn(inputs,
|
||||
if sequence_lengths is not None:
|
||||
args["sequence_lengths"] = sequence_lengths
|
||||
args["time_major"] = time_major
|
||||
args["num_proj"] = 0 if num_proj is None else num_proj
|
||||
outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
|
||||
elif time_major is False:
|
||||
batch_size = array_ops.shape(inputs)[0]
|
||||
max_time = array_ops.shape(inputs)[1]
|
||||
elif time_major is False or num_proj:
|
||||
batch_id, time_id = (1, 0) if time_major else (0, 1)
|
||||
batch_size = array_ops.shape(inputs)[batch_id]
|
||||
max_time = array_ops.shape(inputs)[time_id]
|
||||
sequence_lengths = array_ops.fill([batch_size], max_time)
|
||||
args["sequence_lengths"] = sequence_lengths
|
||||
args["time_major"] = time_major
|
||||
args["num_proj"] = 0 if num_proj is None else num_proj
|
||||
outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
|
||||
elif use_cudnn_v2 != "1":
|
||||
outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args)
|
||||
@ -1061,6 +1157,7 @@ def cudnn_lstm(inputs,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0.,
|
||||
seed=0,
|
||||
num_proj=None,
|
||||
name=None):
|
||||
"""Cudnn LSTM.
|
||||
|
||||
@ -1096,6 +1193,8 @@ def cudnn_lstm(inputs,
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See
|
||||
`tf.compat.v1.set_random_seed` for behavior.
|
||||
num_proj: The output dimensionality for the projection matrices.
|
||||
If None or 0, no projection is performed.
|
||||
name: name of the operation.
|
||||
|
||||
Returns:
|
||||
@ -1103,7 +1202,7 @@ def cudnn_lstm(inputs,
|
||||
"""
|
||||
return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
|
||||
sequence_lengths, time_major, input_mode, direction,
|
||||
dropout, seed, name)
|
||||
dropout, seed, num_proj, name)
|
||||
|
||||
|
||||
def _cudnn_rnn_no_input_c(inputs,
|
||||
@ -1160,7 +1259,7 @@ def _cudnn_rnn_no_input_c(inputs,
|
||||
outputs, output_h, _ = _cudnn_rnn(inputs, input_h, input_c, params,
|
||||
is_training, rnn_mode, sequence_lengths,
|
||||
time_major, input_mode, direction, dropout,
|
||||
seed, name)
|
||||
seed, None, name)
|
||||
return outputs, output_h
|
||||
|
||||
|
||||
@ -1331,6 +1430,7 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0,
|
||||
seed=0,
|
||||
num_proj=None,
|
||||
name=None):
|
||||
"""Convert cudnn opaque params to canonical.
|
||||
|
||||
@ -1353,6 +1453,8 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See
|
||||
`tf.compat.v1.set_random_seed` for behavior.
|
||||
num_proj: The output dimensionality for the projection matrices.
|
||||
If None or 0, no projection is performed.
|
||||
name: name of the operation.
|
||||
|
||||
Returns:
|
||||
@ -1366,19 +1468,39 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
|
||||
check_input_mode(input_mode)
|
||||
num_params = _get_num_params(rnn_mode, num_layers, direction)
|
||||
seed, seed2 = random_seed.get_seed(seed)
|
||||
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
|
||||
rnn_mode=rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
params=params,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
num_params=num_params,
|
||||
name=name)
|
||||
num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2
|
||||
if num_proj is not None and num_proj != 0:
|
||||
num_params_weights = (num_params + 1 * num_layers * num_dirs)
|
||||
num_params_biases = num_params
|
||||
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2(
|
||||
rnn_mode=rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
params=params,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
num_params_weights=num_params_weights,
|
||||
num_params_biases=num_params_biases,
|
||||
num_proj=num_proj,
|
||||
name=name)
|
||||
else:
|
||||
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
|
||||
rnn_mode=rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
params=params,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
num_params=num_params,
|
||||
name=name)
|
||||
return weights, biases
|
||||
|
||||
|
||||
@ -1392,6 +1514,7 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dropout=0,
|
||||
seed=0,
|
||||
num_proj=None,
|
||||
name=None):
|
||||
"""Converts params from the canonical format to a specific format of cuDNN.
|
||||
|
||||
@ -1415,6 +1538,8 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See
|
||||
`tf.compat.v1.set_random_seed` for behavior.
|
||||
num_proj: The output dimensionality for the projection matrices.
|
||||
If None or 0, no projection is performed.
|
||||
name: name of the operation.
|
||||
|
||||
Returns:
|
||||
@ -1426,19 +1551,35 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
|
||||
check_direction(direction)
|
||||
check_input_mode(input_mode)
|
||||
seed, seed2 = random_seed.get_seed(seed)
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
|
||||
rnn_mode=rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
weights=weights,
|
||||
biases=biases,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
name=name)
|
||||
if num_proj is not None and num_proj != 0:
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2(
|
||||
rnn_mode=rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
weights=weights,
|
||||
biases=biases,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
num_proj=num_proj,
|
||||
name=name)
|
||||
else:
|
||||
return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
|
||||
rnn_mode=rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
weights=weights,
|
||||
biases=biases,
|
||||
input_mode=input_mode,
|
||||
direction=direction,
|
||||
dropout=dropout,
|
||||
seed=seed,
|
||||
seed2=seed2,
|
||||
name=name)
|
||||
|
||||
|
||||
def cudnn_rnn_opaque_params_size(rnn_mode,
|
||||
@ -1450,6 +1591,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode,
|
||||
dtype=dtypes.float32,
|
||||
dropout=0,
|
||||
seed=0,
|
||||
num_proj=None,
|
||||
name=None):
|
||||
"""Returns opaque params size for specific Cudnn config.
|
||||
|
||||
@ -1472,6 +1614,8 @@ def cudnn_rnn_opaque_params_size(rnn_mode,
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See
|
||||
`tf.compat.v1.set_random_seed` for behavior.
|
||||
num_proj: The output dimensionality for the projection matrices.
|
||||
If None or 0, no projection is performed.
|
||||
name: name of the operation.
|
||||
|
||||
Returns:
|
||||
@ -1488,6 +1632,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode,
|
||||
num_layers=num_layers,
|
||||
num_units=num_units,
|
||||
input_size=input_size,
|
||||
num_proj=num_proj,
|
||||
T=dtype,
|
||||
S=dtypes.int32,
|
||||
dropout=dropout,
|
||||
@ -1516,7 +1661,8 @@ class _CudnnRNN(object):
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dtype=dtypes.float32,
|
||||
dropout=0.,
|
||||
seed=0):
|
||||
seed=0,
|
||||
num_proj=None):
|
||||
"""Creates a CudnnRNN model from model spec.
|
||||
|
||||
Args:
|
||||
@ -1539,6 +1685,8 @@ class _CudnnRNN(object):
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the op seed used for initializing dropout. See
|
||||
`tf.compat.v1.set_random_seed` for behavior.
|
||||
num_proj: The output dimensionality for the projection matrices.
|
||||
If None or 0, no projection is performed.
|
||||
|
||||
Raises:
|
||||
ValueError: if direction is invalid.
|
||||
@ -1552,6 +1700,7 @@ class _CudnnRNN(object):
|
||||
self._dtype = dtype
|
||||
self._dropout = dropout
|
||||
self._seed = seed
|
||||
self._num_proj = num_proj
|
||||
|
||||
@property
|
||||
def input_mode(self):
|
||||
@ -1577,6 +1726,10 @@ class _CudnnRNN(object):
|
||||
def direction(self):
|
||||
return self._direction
|
||||
|
||||
@property
|
||||
def num_proj(self):
|
||||
return self._num_proj
|
||||
|
||||
def params_size(self):
|
||||
"""Calculates the size of the opaque parameter buffer needed for this model.
|
||||
|
||||
@ -1588,6 +1741,7 @@ class _CudnnRNN(object):
|
||||
num_layers=self._num_layers,
|
||||
num_units=self._num_units,
|
||||
input_size=self._input_size,
|
||||
num_proj=self._num_proj,
|
||||
dtype=self._dtype,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed,
|
||||
@ -1643,7 +1797,8 @@ class _CudnnRNN(object):
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed)
|
||||
seed=self._seed,
|
||||
num_proj=self._num_proj)
|
||||
|
||||
def params_to_canonical(self, params):
|
||||
"""Converts params from a specific format of cuDNN to the canonical format.
|
||||
@ -1663,7 +1818,8 @@ class _CudnnRNN(object):
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed)
|
||||
seed=self._seed,
|
||||
num_proj=self._num_proj)
|
||||
|
||||
def canonical_to_params(self, weights, biases):
|
||||
"""Converts params from the canonical format to a specific format of cuDNN.
|
||||
@ -1685,7 +1841,8 @@ class _CudnnRNN(object):
|
||||
input_mode=self._input_mode,
|
||||
direction=self._direction,
|
||||
dropout=self._dropout,
|
||||
seed=self._seed)
|
||||
seed=self._seed,
|
||||
num_proj=self._num_proj)
|
||||
|
||||
|
||||
class CudnnLSTM(_CudnnRNN):
|
||||
@ -1703,7 +1860,8 @@ class CudnnLSTM(_CudnnRNN):
|
||||
direction=CUDNN_RNN_UNIDIRECTION,
|
||||
dtype=dtypes.float32,
|
||||
dropout=0.,
|
||||
seed=0):
|
||||
seed=0,
|
||||
num_proj=None):
|
||||
"""Creates a Cudnn LSTM model from model spec.
|
||||
|
||||
Args:
|
||||
@ -1721,6 +1879,8 @@ class CudnnLSTM(_CudnnRNN):
|
||||
dtype: dtype of params, tf.float32 or tf.float64.
|
||||
dropout: whether to enable dropout. With it is 0, dropout is disabled.
|
||||
seed: the seed used for initializing dropout.
|
||||
num_proj: The output dimensionality for the projection matrices.
|
||||
If None or 0, no projection is performed.
|
||||
"""
|
||||
super(CudnnLSTM, self).__init__(
|
||||
CUDNN_LSTM,
|
||||
@ -1731,7 +1891,8 @@ class CudnnLSTM(_CudnnRNN):
|
||||
direction=direction,
|
||||
dtype=dtype,
|
||||
dropout=dropout,
|
||||
seed=seed)
|
||||
seed=seed,
|
||||
num_proj=num_proj)
|
||||
|
||||
def __call__(self,
|
||||
input_data,
|
||||
|
@ -0,0 +1,36 @@
|
||||
op {
|
||||
graph_op_name: "CudnnRNNCanonicalToParamsV2"
|
||||
summary: "Converts CudnnRNN params from canonical form to usable form. It supports the projection in LSTM."
|
||||
description: <<END
|
||||
Writes a set of weights into the opaque params buffer so they can be used in
|
||||
upcoming training or inferences.
|
||||
|
||||
Note that the params buffer may not be compatible across different GPUs. So any
|
||||
save and restoration should be converted to and from the canonical weights and
|
||||
biases.
|
||||
|
||||
num_layers: Specifies the number of layers in the RNN model.
|
||||
num_units: Specifies the size of the hidden state.
|
||||
input_size: Specifies the size of the input state.
|
||||
weights: the canonical form of weights that can be used for saving
|
||||
and restoration. They are more likely to be compatible across different
|
||||
generations.
|
||||
biases: the canonical form of biases that can be used for saving
|
||||
and restoration. They are more likely to be compatible across different
|
||||
generations.
|
||||
num_params_weigths: number of weight parameter matrix for all layers.
|
||||
num_params_biases: number of bias parameter vector for all layers.
|
||||
rnn_mode: Indicates the type of the RNN model.
|
||||
input_mode: Indicate whether there is a linear projection between the input and
|
||||
The actual computation before the first layer. 'skip_input' is only allowed
|
||||
when input_size == num_units; 'auto_select' implies 'skip_input' when
|
||||
input_size == num_units; otherwise, it implies 'linear_input'.
|
||||
direction: Indicates whether a bidirectional model will be used.
|
||||
dir = (direction == bidirectional) ? 2 : 1
|
||||
dropout: dropout probability. When set to 0., dropout is disabled.
|
||||
seed: the 1st part of a seed to initialize dropout.
|
||||
seed2: the 2nd part of a seed to initialize dropout.
|
||||
num_proj: The output dimensionality for the projection matrices. If None or 0,
|
||||
no projection is performed.
|
||||
END
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
op {
|
||||
graph_op_name: "CudnnRNNParamsToCanonicalV2"
|
||||
summary: "Retrieves CudnnRNN params in canonical form. It supports the projection in LSTM."
|
||||
description: <<END
|
||||
Retrieves a set of weights from the opaque params buffer that can be saved and
|
||||
restored in a way compatible with future runs.
|
||||
|
||||
Note that the params buffer may not be compatible across different GPUs. So any
|
||||
save and restoration should be converted to and from the canonical weights and
|
||||
biases.
|
||||
|
||||
num_layers: Specifies the number of layers in the RNN model.
|
||||
num_units: Specifies the size of the hidden state.
|
||||
input_size: Specifies the size of the input state.
|
||||
num_params_weigths: number of weight parameter matrix for all layers.
|
||||
num_params_biases: number of bias parameter vector for all layers.
|
||||
weights: the canonical form of weights that can be used for saving
|
||||
and restoration. They are more likely to be compatible across different
|
||||
generations.
|
||||
biases: the canonical form of biases that can be used for saving
|
||||
and restoration. They are more likely to be compatible across different
|
||||
generations.
|
||||
rnn_mode: Indicates the type of the RNN model.
|
||||
input_mode: Indicate whether there is a linear projection between the input and
|
||||
The actual computation before the first layer. 'skip_input' is only allowed
|
||||
when input_size == num_units; 'auto_select' implies 'skip_input' when
|
||||
input_size == num_units; otherwise, it implies 'linear_input'.
|
||||
direction: Indicates whether a bidirectional model will be used.
|
||||
dir = (direction == bidirectional) ? 2 : 1
|
||||
dropout: dropout probability. When set to 0., dropout is disabled.
|
||||
seed: the 1st part of a seed to initialize dropout.
|
||||
seed2: the 2nd part of a seed to initialize dropout.
|
||||
num_proj: The output dimensionality for the projection matrices. If None or 0,
|
||||
no projection is performed.
|
||||
END
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "CudnnRNNCanonicalToParamsV2"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "CudnnRNNParamsToCanonicalV2"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -502,20 +502,23 @@ struct CudnnRnnModelShapes {
|
||||
int dir_count;
|
||||
int max_seq_length;
|
||||
int batch_size;
|
||||
int cell_num_units = 0;
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
TensorShape hidden_state_shape;
|
||||
TensorShape cell_state_shape;
|
||||
// At present only fields related to cached RnnDescriptor are concerned.
|
||||
bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
|
||||
return num_layers == rhs.num_layers && input_size == rhs.input_size &&
|
||||
num_units == rhs.num_units && dir_count == rhs.dir_count;
|
||||
num_units == rhs.num_units && dir_count == rhs.dir_count &&
|
||||
cell_num_units == rhs.cell_num_units;
|
||||
}
|
||||
string DebugString() const {
|
||||
return strings::Printf(
|
||||
"[num_layers, input_size, num_units, dir_count, max_seq_length, "
|
||||
"batch_size]: [%d, %d, %d, %d, %d, %d] ",
|
||||
"batch_size, cell_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
|
||||
num_layers, input_size, num_units, dir_count, max_seq_length,
|
||||
batch_size);
|
||||
batch_size, cell_num_units);
|
||||
}
|
||||
};
|
||||
|
||||
@ -562,6 +565,7 @@ Status ExtractForwardInput(OpKernelContext* context,
|
||||
const CudnnModelTypes& model_types, bool time_major,
|
||||
const Tensor** input, const Tensor** input_h,
|
||||
const Tensor** input_c, const Tensor** params,
|
||||
const int num_proj,
|
||||
CudnnRnnModelShapes* model_shapes) {
|
||||
TF_RETURN_IF_ERROR(context->input("input", input));
|
||||
TF_RETURN_IF_ERROR(context->input("input_h", input_h));
|
||||
@ -615,12 +619,48 @@ Status ExtractForwardInput(OpKernelContext* context,
|
||||
model_shapes->hidden_state_shape.DebugString());
|
||||
}
|
||||
if (model_types.HasInputC()) {
|
||||
if ((*input_h)->shape() != (*input_c)->shape()) {
|
||||
return errors::InvalidArgument(
|
||||
"input_h and input_c must have the same shape: ",
|
||||
(*input_h)->shape().DebugString(), " ",
|
||||
(*input_c)->shape().DebugString());
|
||||
model_shapes->cell_num_units = (*input_c)->dim_size(2);
|
||||
if (time_major) {
|
||||
model_shapes->cell_state_shape =
|
||||
TensorShape({model_shapes->dir_count * model_shapes->num_layers,
|
||||
model_shapes->batch_size, model_shapes->cell_num_units});
|
||||
} else {
|
||||
model_shapes->cell_state_shape =
|
||||
TensorShape({model_shapes->batch_size,
|
||||
model_shapes->dir_count * model_shapes->num_layers,
|
||||
model_shapes->cell_num_units});
|
||||
}
|
||||
if (num_proj == 0) {
|
||||
if ((*input_h)->shape() != (*input_c)->shape()) {
|
||||
return errors::InvalidArgument(
|
||||
"input_h and input_c must have the same shape w/o projection: ",
|
||||
(*input_h)->shape().DebugString(), " ",
|
||||
(*input_c)->shape().DebugString());
|
||||
}
|
||||
} else {
|
||||
if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) ||
|
||||
num_proj != (*input_h)->dim_size(2) ||
|
||||
(*input_h)->dim_size(0) != (*input_c)->dim_size(0) ||
|
||||
(*input_h)->dim_size(1) != (*input_c)->dim_size(1)) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid input_h and input_c w/ projection size: ", num_proj, " ",
|
||||
(*input_h)->shape().DebugString(), " ",
|
||||
(*input_c)->shape().DebugString());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// dummy cell_state_shape TODO(kaixih): remove the time_major branch
|
||||
if (time_major) {
|
||||
model_shapes->cell_state_shape =
|
||||
TensorShape({model_shapes->dir_count * model_shapes->num_layers,
|
||||
model_shapes->batch_size, model_shapes->num_units});
|
||||
} else {
|
||||
model_shapes->cell_state_shape =
|
||||
TensorShape({model_shapes->batch_size,
|
||||
model_shapes->dir_count * model_shapes->num_layers,
|
||||
model_shapes->num_units});
|
||||
}
|
||||
model_shapes->cell_num_units = 0;
|
||||
}
|
||||
if (time_major) {
|
||||
model_shapes->output_shape =
|
||||
@ -639,18 +679,19 @@ Status ExtractForwardInput(OpKernelContext* context,
|
||||
const CudnnModelTypes& model_types, bool time_major,
|
||||
const Tensor** input, const Tensor** input_h,
|
||||
const Tensor** input_c, const Tensor** params,
|
||||
const Tensor** sequence_lengths,
|
||||
const Tensor** sequence_lengths, const int num_proj,
|
||||
CudnnRnnModelShapes* model_shapes) {
|
||||
TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
|
||||
return ExtractForwardInput(context, model_types, time_major, input, input_h,
|
||||
input_c, params, model_shapes);
|
||||
input_c, params, num_proj, model_shapes);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status CreateForwardAndBackwardIODescriptors(
|
||||
OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
|
||||
std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
|
||||
std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
|
||||
std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc,
|
||||
std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc,
|
||||
std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
|
||||
const absl::Span<const int>& seq_lengths, bool time_major) {
|
||||
StreamExecutor* executor = context->op_device_context()->stream()->parent();
|
||||
@ -658,6 +699,7 @@ Status CreateForwardAndBackwardIODescriptors(
|
||||
|
||||
const TensorShape& input_shape = model_shapes.input_shape;
|
||||
const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
|
||||
const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
|
||||
const TensorShape& output_shape = model_shapes.output_shape;
|
||||
|
||||
DCHECK_EQ(input_shape.dims(), 3);
|
||||
@ -689,13 +731,28 @@ Status CreateForwardAndBackwardIODescriptors(
|
||||
hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
|
||||
hidden_state_shape.dim_size(2), data_type);
|
||||
TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
|
||||
*state_desc = hidden_state_desc_s.ConsumeValueOrDie();
|
||||
*h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
|
||||
} else {
|
||||
auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
|
||||
hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
|
||||
hidden_state_shape.dim_size(2), data_type);
|
||||
TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
|
||||
*state_desc = hidden_state_desc_s.ConsumeValueOrDie();
|
||||
*h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
DCHECK_EQ(cell_state_shape.dims(), 3);
|
||||
if (time_major) {
|
||||
auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
|
||||
cell_state_shape.dim_size(0), cell_state_shape.dim_size(1),
|
||||
cell_state_shape.dim_size(2), data_type);
|
||||
TF_RETURN_IF_ERROR(cell_state_desc_s.status());
|
||||
*c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
|
||||
} else {
|
||||
auto cell_state_desc_s = executor->createRnnStateTensorDescriptor(
|
||||
cell_state_shape.dim_size(1), cell_state_shape.dim_size(0),
|
||||
cell_state_shape.dim_size(2), data_type);
|
||||
TF_RETURN_IF_ERROR(cell_state_desc_s.status());
|
||||
*c_state_desc = cell_state_desc_s.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
DCHECK_EQ(output_shape.dims(), 3);
|
||||
@ -739,7 +796,8 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
|
||||
ScratchAllocator* workspace_allocator,
|
||||
ProfileResult* output_profile_result) {
|
||||
std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
|
||||
std::unique_ptr<RnnStateTensorDescriptor> state_desc;
|
||||
std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
|
||||
std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
|
||||
std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
|
||||
|
||||
absl::Span<const int> seq_lengths;
|
||||
@ -748,8 +806,8 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
|
||||
sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
|
||||
context, model_shapes, &input_desc, &state_desc, &output_desc,
|
||||
seq_lengths, time_major));
|
||||
context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
|
||||
&output_desc, seq_lengths, time_major));
|
||||
|
||||
auto input_data = AsDeviceMemory<T>(input);
|
||||
auto input_h_data = AsDeviceMemory<T>(input_h);
|
||||
@ -769,11 +827,11 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
|
||||
Stream* stream = context->op_device_context()->stream();
|
||||
bool launch_success =
|
||||
stream
|
||||
->ThenRnnForward(rnn_desc, *input_desc, input_data, *state_desc,
|
||||
input_h_data, *state_desc, input_c_data, params_data,
|
||||
*output_desc, &output_data, *state_desc,
|
||||
&output_h_data, *state_desc, &output_c_data,
|
||||
is_training, reserve_space_allocator,
|
||||
->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc,
|
||||
input_h_data, *c_state_desc, input_c_data,
|
||||
params_data, *output_desc, &output_data,
|
||||
*h_state_desc, &output_h_data, *c_state_desc,
|
||||
&output_c_data, is_training, reserve_space_allocator,
|
||||
workspace_allocator, output_profile_result)
|
||||
.ok();
|
||||
return launch_success
|
||||
@ -801,7 +859,8 @@ Status DoBackward(
|
||||
ScratchAllocator* workspace_allocator,
|
||||
ProfileResult* output_profile_result) {
|
||||
std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
|
||||
std::unique_ptr<RnnStateTensorDescriptor> state_desc;
|
||||
std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
|
||||
std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
|
||||
std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
|
||||
|
||||
absl::Span<const int> seq_lengths;
|
||||
@ -810,8 +869,8 @@ Status DoBackward(
|
||||
sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
|
||||
context, model_shapes, &input_desc, &state_desc, &output_desc,
|
||||
seq_lengths, time_major));
|
||||
context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
|
||||
&output_desc, seq_lengths, time_major));
|
||||
|
||||
auto input_data = AsDeviceMemory<T>(input);
|
||||
auto input_h_data = AsDeviceMemory<T>(input_h);
|
||||
@ -847,15 +906,15 @@ Status DoBackward(
|
||||
Stream* stream = context->op_device_context()->stream();
|
||||
bool launch_success =
|
||||
stream
|
||||
->ThenRnnBackward(rnn_desc, *input_desc, input_data, *state_desc,
|
||||
input_h_data, *state_desc, input_c_data,
|
||||
params_data, *output_desc, output_data, *state_desc,
|
||||
output_h_data, *state_desc, output_c_data,
|
||||
output_backprop_data, output_h_backprop_data,
|
||||
output_c_backprop_data, &input_backprop_data,
|
||||
&input_h_backprop_data, &input_c_backprop_data,
|
||||
¶ms_backprop_data, &reserve_space_uint8,
|
||||
workspace_allocator, output_profile_result)
|
||||
->ThenRnnBackward(
|
||||
rnn_desc, *input_desc, input_data, *h_state_desc, input_h_data,
|
||||
*c_state_desc, input_c_data, params_data, *output_desc,
|
||||
output_data, *h_state_desc, output_h_data, *c_state_desc,
|
||||
output_c_data, output_backprop_data, output_h_backprop_data,
|
||||
output_c_backprop_data, &input_backprop_data,
|
||||
&input_h_backprop_data, &input_c_backprop_data,
|
||||
¶ms_backprop_data, &reserve_space_uint8, workspace_allocator,
|
||||
output_profile_result)
|
||||
.ok();
|
||||
return launch_success
|
||||
? Status::OK()
|
||||
@ -932,7 +991,7 @@ class CudnnRNNKernelCommon : public OpKernel {
|
||||
bool ResetRndGenState() { return reset_rnd_gen_state_; }
|
||||
|
||||
template <typename T>
|
||||
Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
|
||||
Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj,
|
||||
std::unique_ptr<RnnDescriptor>* rnn_desc) {
|
||||
const Tensor* num_layers_t = nullptr;
|
||||
TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
|
||||
@ -953,6 +1012,9 @@ class CudnnRNNKernelCommon : public OpKernel {
|
||||
}
|
||||
int input_size = input_size_t->scalar<int>()();
|
||||
|
||||
int h_num_units = (num_proj == 0 ? num_units : num_proj);
|
||||
int c_num_units = (num_proj == 0 ? 0 : num_units);
|
||||
|
||||
RnnInputMode input_mode;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
|
||||
@ -962,9 +1024,10 @@ class CudnnRNNKernelCommon : public OpKernel {
|
||||
// random number generator, therefore set state_allocator to nullptr.
|
||||
const AlgorithmConfig algo_config;
|
||||
auto rnn_desc_s = stream->parent()->createRnnDescriptor(
|
||||
num_layers, num_units, input_size, /*batch_size=*/0, input_mode,
|
||||
rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config,
|
||||
dropout(), seed(), /* state_allocator=*/nullptr);
|
||||
num_layers, h_num_units, input_size, /*cell_size=*/c_num_units,
|
||||
/*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
|
||||
ToDataType<T>::value, algo_config, dropout(), seed(),
|
||||
/* state_allocator=*/nullptr);
|
||||
if (!rnn_desc_s.ok()) {
|
||||
return FromExecutorStatus(rnn_desc_s);
|
||||
}
|
||||
@ -983,9 +1046,9 @@ class CudnnRNNKernelCommon : public OpKernel {
|
||||
se::dnn::DataType data_type = ToDataType<T>::value;
|
||||
auto rnn_desc_s = executor->createRnnDescriptor(
|
||||
model_shapes.num_layers, model_shapes.num_units,
|
||||
model_shapes.input_size, model_shapes.batch_size, input_mode,
|
||||
rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(),
|
||||
seed(), dropout_state_allocator);
|
||||
model_shapes.input_size, model_shapes.cell_num_units,
|
||||
model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
|
||||
data_type, algo_config, dropout(), seed(), dropout_state_allocator);
|
||||
TF_RETURN_IF_ERROR(rnn_desc_s.status());
|
||||
|
||||
*rnn_desc = rnn_desc_s.ConsumeValueOrDie();
|
||||
@ -1035,11 +1098,18 @@ template <typename T, typename Index>
|
||||
class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
|
||||
public:
|
||||
explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
|
||||
: CudnnRNNKernelCommon(context) {}
|
||||
: CudnnRNNKernelCommon(context) {
|
||||
if (context->HasAttr("num_proj")) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
|
||||
} else {
|
||||
num_proj_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
std::unique_ptr<RnnDescriptor> rnn_desc;
|
||||
OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
|
||||
OP_REQUIRES_OK(context,
|
||||
ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
|
||||
int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
|
||||
CHECK(params_size_in_bytes % sizeof(T) == 0)
|
||||
<< "params_size_in_bytes must be multiple of element size";
|
||||
@ -1049,6 +1119,9 @@ class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
|
||||
*output_t->template flat<Index>().data() = params_size;
|
||||
}
|
||||
|
||||
private:
|
||||
int num_proj_;
|
||||
};
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
@ -1074,7 +1147,32 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
public:
|
||||
explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
|
||||
: CudnnRNNKernelCommon(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
|
||||
if (context->HasAttr("num_params")) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
|
||||
} else {
|
||||
num_params_ = 0;
|
||||
}
|
||||
if (context->HasAttr("num_params_weights")) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_params_weights",
|
||||
&num_params_weights_));
|
||||
} else {
|
||||
num_params_weights_ = 0;
|
||||
}
|
||||
if (context->HasAttr("num_params_biases")) {
|
||||
OP_REQUIRES_OK(
|
||||
context, context->GetAttr("num_params_biases", &num_params_biases_));
|
||||
} else {
|
||||
num_params_biases_ = 0;
|
||||
}
|
||||
if (context->HasAttr("num_proj")) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
|
||||
} else {
|
||||
num_proj_ = 0;
|
||||
}
|
||||
if (num_proj_ == 0) {
|
||||
num_params_weights_ = num_params_;
|
||||
num_params_biases_ = num_params_;
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
@ -1083,7 +1181,8 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
Stream* stream = context->op_device_context()->stream();
|
||||
|
||||
std::unique_ptr<RnnDescriptor> rnn_desc;
|
||||
OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
|
||||
OP_REQUIRES_OK(context,
|
||||
ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
|
||||
int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
|
||||
CHECK(params_size_in_bytes % sizeof(T) == 0)
|
||||
<< "params_size_in_bytes must be multiple of element size";
|
||||
@ -1109,25 +1208,46 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
|
||||
num_dirs = 2;
|
||||
}
|
||||
const int num_params_per_layer = num_params_ / num_layers / num_dirs;
|
||||
const int num_params_weights_per_layer =
|
||||
num_params_weights_ / num_layers / num_dirs;
|
||||
// Number of params applied on inputs. The rest are applied on recurrent
|
||||
// hidden states.
|
||||
const int num_params_input_state = num_params_per_layer / 2;
|
||||
CHECK(num_params_ % (num_layers * num_dirs) == 0)
|
||||
<< "Number of params is not a multiple of num_layers * num_dirs.";
|
||||
CHECK(num_params_per_layer % 2 == 0)
|
||||
<< "Number of params per layer is not a even number.";
|
||||
const int num_params_input_state = num_params_weights_per_layer / 2;
|
||||
OP_REQUIRES(
|
||||
context, num_params_weights_ % (num_layers * num_dirs) == 0,
|
||||
errors::InvalidArgument("Number of params (weights) is not a multiple"
|
||||
"of num_layers * num_dirs."));
|
||||
OP_REQUIRES(
|
||||
context, num_params_biases_ % (num_layers * num_dirs) == 0,
|
||||
errors::InvalidArgument("Number of params (biases) is not a multiple"
|
||||
"of num_layers * num_dirs."));
|
||||
if (num_proj_ == 0) {
|
||||
OP_REQUIRES(
|
||||
context, num_params_weights_per_layer % 2 == 0,
|
||||
errors::InvalidArgument("Number of params (weights) per layer is not"
|
||||
"an even number with no projection."));
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
context, num_params_weights_per_layer % 2 != 0,
|
||||
errors::InvalidArgument("Number of params (weights) per layer is not"
|
||||
"an odl number with projection."));
|
||||
}
|
||||
|
||||
CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
|
||||
<< "Number of params mismatch. Expected " << num_params_ << ", got "
|
||||
<< rnn_desc->ParamsWeightRegions().size();
|
||||
OP_REQUIRES(
|
||||
context, num_params_weights_ == rnn_desc->ParamsWeightRegions().size(),
|
||||
errors::InvalidArgument("C Number of params mismatch. Expected ",
|
||||
num_params_weights_, ", got ",
|
||||
rnn_desc->ParamsWeightRegions().size()));
|
||||
int h_num_units = (num_proj_ == 0 ? num_units : num_proj_);
|
||||
int c_num_units = (num_proj_ == 0 ? 0 : num_units);
|
||||
for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
|
||||
int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
|
||||
int64 size = size_in_bytes / sizeof(T);
|
||||
const int layer_idx = i / num_params_per_layer;
|
||||
const int index_within_layer = i % num_params_per_layer;
|
||||
int width = 0, height = num_units;
|
||||
// In CuDNN layout, each layer has num_params_per_layer params, with the
|
||||
const int layer_idx = i / num_params_weights_per_layer;
|
||||
const int index_within_layer = i % num_params_weights_per_layer;
|
||||
int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units);
|
||||
// In CuDNN layout, each layer has num_params_weights_per_layer params,
|
||||
// with the
|
||||
// first half a.k.a num_params_input_state params applied on the inputs,
|
||||
// and the second half on the recurrent hidden states.
|
||||
bool apply_on_input_state = index_within_layer < num_params_input_state;
|
||||
@ -1135,7 +1255,7 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
if (layer_idx == 0 && apply_on_input_state) {
|
||||
width = input_size;
|
||||
} else {
|
||||
width = num_units;
|
||||
width = h_num_units;
|
||||
}
|
||||
} else {
|
||||
if (apply_on_input_state) {
|
||||
@ -1145,15 +1265,19 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
} else {
|
||||
// Following layers, cell inputs are concatenated outputs of
|
||||
// its prior layer.
|
||||
width = 2 * num_units;
|
||||
width = 2 * h_num_units;
|
||||
}
|
||||
} else {
|
||||
width = num_units;
|
||||
width = h_num_units;
|
||||
}
|
||||
}
|
||||
CHECK(size == width * height) << "Params size mismatch. Expected "
|
||||
<< width * height << ", got " << size;
|
||||
Tensor* output = nullptr;
|
||||
int id_in_layer = i % num_params_weights_per_layer;
|
||||
if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer - 1) {
|
||||
std::swap(height, width);
|
||||
}
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
i, TensorShape({height, width}), &output));
|
||||
DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
|
||||
@ -1162,10 +1286,11 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
|
||||
}
|
||||
|
||||
OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(),
|
||||
errors::InvalidArgument("Number of params mismatch. Expected ",
|
||||
num_params_, ", got ",
|
||||
rnn_desc->ParamsBiasRegions().size()));
|
||||
OP_REQUIRES(
|
||||
context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(),
|
||||
errors::InvalidArgument("A Number of params mismatch. Expected ",
|
||||
num_params_biases_, ", got ",
|
||||
rnn_desc->ParamsBiasRegions().size()));
|
||||
for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
|
||||
int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
|
||||
int64 size = size_in_bytes / sizeof(T);
|
||||
@ -1175,7 +1300,7 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(num_params_ + i,
|
||||
context->allocate_output(num_params_weights_ + i,
|
||||
TensorShape({size}), &output));
|
||||
DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
|
||||
input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
|
||||
@ -1186,6 +1311,9 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
|
||||
private:
|
||||
int num_params_;
|
||||
int num_params_weights_;
|
||||
int num_params_biases_;
|
||||
int num_proj_;
|
||||
};
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
@ -1201,17 +1329,37 @@ TF_CALL_float(REGISTER_GPU);
|
||||
TF_CALL_double(REGISTER_GPU);
|
||||
#undef REGISTER_GPU
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("num_layers") \
|
||||
.HostMemory("num_units") \
|
||||
.HostMemory("input_size") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
CudnnRNNParamsToCanonical<GPUDevice, T>);
|
||||
TF_CALL_half(REGISTER_GPU);
|
||||
TF_CALL_float(REGISTER_GPU);
|
||||
TF_CALL_double(REGISTER_GPU);
|
||||
#undef REGISTER_GPU
|
||||
|
||||
// Convert weight and bias params from the canonical form to a
|
||||
// platform-specific layout.
|
||||
template <typename T>
|
||||
class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
public:
|
||||
explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
|
||||
: CudnnRNNKernelCommon(context) {}
|
||||
: CudnnRNNKernelCommon(context) {
|
||||
if (context->HasAttr("num_proj")) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
|
||||
} else {
|
||||
num_proj_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
std::unique_ptr<RnnDescriptor> rnn_desc;
|
||||
OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
|
||||
OP_REQUIRES_OK(context,
|
||||
ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
|
||||
int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
|
||||
CHECK(params_size_in_bytes % sizeof(T) == 0)
|
||||
<< "params_size_in_bytes must be multiple of element size";
|
||||
@ -1232,6 +1380,9 @@ class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
|
||||
stream);
|
||||
}
|
||||
|
||||
private:
|
||||
int num_proj_;
|
||||
};
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
@ -1247,6 +1398,19 @@ TF_CALL_float(REGISTER_GPU);
|
||||
TF_CALL_double(REGISTER_GPU);
|
||||
#undef REGISTER_GPU
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("num_layers") \
|
||||
.HostMemory("num_units") \
|
||||
.HostMemory("input_size") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
CudnnRNNCanonicalToParams<GPUDevice, T>);
|
||||
TF_CALL_half(REGISTER_GPU);
|
||||
TF_CALL_float(REGISTER_GPU);
|
||||
TF_CALL_double(REGISTER_GPU);
|
||||
#undef REGISTER_GPU
|
||||
|
||||
// Run the forward operation of the RNN model.
|
||||
template <typename T>
|
||||
class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
@ -1264,14 +1428,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
void Compute(OpKernelContext* context) override {
|
||||
AlgorithmConfig algo_config;
|
||||
ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
|
||||
/*time_major=*/true);
|
||||
/*time_major=*/true, /*num_proj=*/0);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
|
||||
AlgorithmConfig* output_algo_config,
|
||||
bool var_seq_lengths,
|
||||
bool time_major) {
|
||||
bool var_seq_lengths, bool time_major,
|
||||
int num_proj) {
|
||||
CHECK_NE(output_algo_config, nullptr);
|
||||
|
||||
const Tensor* input = nullptr;
|
||||
@ -1281,14 +1445,15 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
const Tensor* sequence_lengths = nullptr;
|
||||
CudnnRnnModelShapes model_shapes;
|
||||
if (var_seq_lengths) {
|
||||
OP_REQUIRES_OK(context, ExtractForwardInput(
|
||||
context, model_types(), time_major, &input,
|
||||
&input_h, &input_c, ¶ms,
|
||||
&sequence_lengths, num_proj, &model_shapes));
|
||||
} else {
|
||||
OP_REQUIRES_OK(context,
|
||||
ExtractForwardInput(context, model_types(), time_major,
|
||||
&input, &input_h, &input_c, ¶ms,
|
||||
&sequence_lengths, &model_shapes));
|
||||
} else {
|
||||
OP_REQUIRES_OK(context, ExtractForwardInput(
|
||||
context, model_types(), time_major, &input,
|
||||
&input_h, &input_c, ¶ms, &model_shapes));
|
||||
num_proj, &model_shapes));
|
||||
}
|
||||
RnnInputMode input_mode;
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -1362,13 +1527,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
Tensor** output_c) {
|
||||
const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
|
||||
const TensorShape& output_shape = model_shapes.output_shape;
|
||||
const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
|
||||
|
||||
TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->allocate_output(1, hidden_state_shape, output_h));
|
||||
if (HasInputC()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->allocate_output(2, hidden_state_shape, output_c));
|
||||
context->allocate_output(2, cell_state_shape, output_c));
|
||||
} else {
|
||||
// Only LSTM uses input_c and output_c. So for all other models, we only
|
||||
// need to create dummy outputs.
|
||||
@ -1414,7 +1580,7 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
|
||||
AlgorithmConfig best_algo_config;
|
||||
CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
|
||||
context, &best_algo_config, /*var_seq_lengths=*/false,
|
||||
/*time_major=*/true);
|
||||
/*time_major=*/true, /*num_proj=*/0);
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -1613,13 +1779,18 @@ class CudnnRNNForwardOpV3<GPUDevice, T>
|
||||
explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
|
||||
: CudnnRNNForwardOp<GPUDevice, T>(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
|
||||
if (context->HasAttr("num_proj")) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
|
||||
} else {
|
||||
num_proj_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
AlgorithmConfig best_algo_config;
|
||||
CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
|
||||
context, &best_algo_config, /*var_seq_lengths=*/true,
|
||||
/*time_major=*/time_major());
|
||||
/*time_major=*/time_major(), num_proj_);
|
||||
if (!context->status().ok()) {
|
||||
return;
|
||||
}
|
||||
@ -1631,6 +1802,9 @@ class CudnnRNNForwardOpV3<GPUDevice, T>
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(4, {}, &output_host_reserved));
|
||||
}
|
||||
|
||||
private:
|
||||
int num_proj_;
|
||||
};
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
@ -1654,12 +1828,12 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
: CudnnRNNKernelCommon(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
ComputeImpl(context, false, true);
|
||||
ComputeImpl(context, false, true, 0);
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
|
||||
bool time_major) {
|
||||
bool time_major, int num_proj) {
|
||||
const Tensor* input = nullptr;
|
||||
const Tensor* input_h = nullptr;
|
||||
const Tensor* input_c = nullptr;
|
||||
@ -1667,14 +1841,15 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
const Tensor* sequence_lengths = nullptr;
|
||||
CudnnRnnModelShapes model_shapes;
|
||||
if (var_seq_lengths) {
|
||||
OP_REQUIRES_OK(context, ExtractForwardInput(
|
||||
context, model_types(), time_major, &input,
|
||||
&input_h, &input_c, ¶ms,
|
||||
&sequence_lengths, num_proj, &model_shapes));
|
||||
} else {
|
||||
OP_REQUIRES_OK(context,
|
||||
ExtractForwardInput(context, model_types(), time_major,
|
||||
&input, &input_h, &input_c, ¶ms,
|
||||
&sequence_lengths, &model_shapes));
|
||||
} else {
|
||||
OP_REQUIRES_OK(context, ExtractForwardInput(
|
||||
context, model_types(), time_major, &input,
|
||||
&input_h, &input_c, ¶ms, &model_shapes));
|
||||
num_proj, &model_shapes));
|
||||
}
|
||||
RnnInputMode input_mode;
|
||||
OP_REQUIRES_OK(context,
|
||||
@ -1757,6 +1932,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
|
||||
const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
|
||||
const TensorShape& output_shape = model_shapes.output_shape;
|
||||
const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
|
||||
|
||||
if (output_shape != (*output)->shape()) {
|
||||
return errors::InvalidArgument(
|
||||
@ -1782,16 +1958,16 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
}
|
||||
|
||||
if (model_types.HasInputC()) {
|
||||
if (hidden_state_shape != (*output_c)->shape()) {
|
||||
if (cell_state_shape != (*output_c)->shape()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
|
||||
hidden_state_shape.DebugString());
|
||||
cell_state_shape.DebugString());
|
||||
}
|
||||
if (hidden_state_shape != (*output_c_backprop)->shape()) {
|
||||
if (cell_state_shape != (*output_c_backprop)->shape()) {
|
||||
return errors::InvalidArgument(
|
||||
"Invalid output_c_backprop shape: ",
|
||||
(*output_c_backprop)->shape().DebugString(), " ",
|
||||
hidden_state_shape.DebugString());
|
||||
cell_state_shape.DebugString());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -1804,6 +1980,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
Tensor** input_c_backprop, Tensor** params_backprop) {
|
||||
const TensorShape& input_shape = model_shapes.input_shape;
|
||||
const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
|
||||
const TensorShape& cell_state_shape = model_shapes.cell_state_shape;
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->allocate_output(0, input_shape, input_backprop));
|
||||
@ -1811,7 +1988,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
|
||||
context->allocate_output(1, hidden_state_shape, input_h_backprop));
|
||||
if (HasInputC()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->allocate_output(2, hidden_state_shape, input_c_backprop));
|
||||
context->allocate_output(2, cell_state_shape, input_c_backprop));
|
||||
} else {
|
||||
// Only LSTM uses input_c and output_c. So for all other models, we only
|
||||
// need to create dummy outputs.
|
||||
@ -1879,11 +2056,20 @@ class CudnnRNNBackwardOpV3<GPUDevice, T>
|
||||
explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
|
||||
: CudnnRNNBackwardOp<GPUDevice, T>(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
|
||||
if (context->HasAttr("num_proj")) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
|
||||
} else {
|
||||
num_proj_ = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major());
|
||||
CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major(),
|
||||
num_proj_);
|
||||
}
|
||||
|
||||
private:
|
||||
int num_proj_;
|
||||
};
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
|
@ -49,6 +49,7 @@ REGISTER_OP("CudnnRNNParamsSize")
|
||||
.Attr("dropout: float = 0.0")
|
||||
.Attr("seed: int = 0")
|
||||
.Attr("seed2: int = 0")
|
||||
.Attr("num_proj: int = 0")
|
||||
.Output("params_size: S")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
@ -166,11 +167,13 @@ REGISTER_OP("CudnnRNNV3")
|
||||
.Attr("dropout: float = 0.0")
|
||||
.Attr("seed: int = 0")
|
||||
.Attr("seed2: int = 0")
|
||||
.Attr("num_proj: int = 0")
|
||||
.Attr("is_training: bool = true")
|
||||
.Attr("time_major: bool = true")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
auto input_shape = c->input(0);
|
||||
auto input_h_shape = c->input(1);
|
||||
auto input_c_shape = c->input(2);
|
||||
auto max_seq_length = c->Dim(input_shape, 0);
|
||||
auto batch_size = c->Dim(input_shape, 1);
|
||||
auto num_units = c->Dim(input_h_shape, 2);
|
||||
@ -185,7 +188,7 @@ REGISTER_OP("CudnnRNNV3")
|
||||
c->MakeShape({max_seq_length, batch_size, output_size});
|
||||
auto output_h_shape = input_h_shape;
|
||||
auto output_c_shape TF_ATTRIBUTE_UNUSED =
|
||||
(rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
|
||||
(rnn_mode == "lstm") ? input_c_shape : c->MakeShape({});
|
||||
c->set_output(0, output_shape);
|
||||
c->set_output(1, output_h_shape);
|
||||
c->set_output(2, output_c_shape);
|
||||
@ -293,6 +296,7 @@ REGISTER_OP("CudnnRNNBackpropV3")
|
||||
.Attr("dropout: float = 0.0")
|
||||
.Attr("seed: int = 0")
|
||||
.Attr("seed2: int = 0")
|
||||
.Attr("num_proj: int = 0")
|
||||
.Attr("time_major: bool = true")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
auto input_shape = c->input(0);
|
||||
@ -338,6 +342,43 @@ REGISTER_OP("CudnnRNNParamsToCanonical")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("CudnnRNNParamsToCanonicalV2")
|
||||
.Input("num_layers: int32")
|
||||
.Input("num_units: int32")
|
||||
.Input("input_size: int32")
|
||||
.Input("params: T")
|
||||
.Output("weights: num_params_weights * T")
|
||||
.Output("biases: num_params_biases * T")
|
||||
.Attr("T: {float16, float32, float64}")
|
||||
.Attr("num_params_weights: int")
|
||||
.Attr("num_params_biases: int")
|
||||
.Attr(kRNNModeAttrs)
|
||||
.Attr(kRNNInputModeAttrs)
|
||||
.Attr(kRNNDirectionAttrs)
|
||||
.Attr("dropout: float = 0.0")
|
||||
.Attr("seed: int = 0")
|
||||
.Attr("seed2: int = 0")
|
||||
.Attr("num_proj: int = 0")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
|
||||
int num_params_weights;
|
||||
int num_params_biases;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_params_weights", &num_params_weights));
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("num_params_biases", &num_params_biases));
|
||||
// Set shape for weight matrices
|
||||
for (int i = 0; i < num_params_weights; i++) {
|
||||
c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
|
||||
InferenceContext::kUnknownDim));
|
||||
}
|
||||
// Set shape for bias vectors
|
||||
for (int i = 0; i < num_params_biases; i++) {
|
||||
c->set_output(num_params_weights + i,
|
||||
c->Vector(InferenceContext::kUnknownDim));
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("CudnnRNNCanonicalToParams")
|
||||
.Input("num_layers: int32")
|
||||
.Input("num_units: int32")
|
||||
@ -358,4 +399,26 @@ REGISTER_OP("CudnnRNNCanonicalToParams")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("CudnnRNNCanonicalToParamsV2")
|
||||
.Input("num_layers: int32")
|
||||
.Input("num_units: int32")
|
||||
.Input("input_size: int32")
|
||||
.Input("weights: num_params_weights * T")
|
||||
.Input("biases: num_params_biases * T")
|
||||
.Output("params: T")
|
||||
.Attr("T: {float16, float32, float64}")
|
||||
.Attr("num_params_weights: int")
|
||||
.Attr("num_params_biases: int")
|
||||
.Attr(kRNNModeAttrs)
|
||||
.Attr(kRNNInputModeAttrs)
|
||||
.Attr(kRNNDirectionAttrs)
|
||||
.Attr("dropout: float = 0.0")
|
||||
.Attr("seed: int = 0")
|
||||
.Attr("seed2: int = 0")
|
||||
.Attr("num_proj: int = 0")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -111,6 +111,8 @@ TEST(CudnnRNNOpsTest, ForwardV3Lstm_ShapeFn) {
|
||||
std::vector<int> input_shape = {max_seq_length, batch_size, num_units};
|
||||
std::vector<int> input_h_shape = {num_layers * dir_count, batch_size,
|
||||
num_units};
|
||||
std::vector<int> input_c_shape = {num_layers * dir_count, batch_size,
|
||||
num_units};
|
||||
std::vector<int> output_shape = {max_seq_length, batch_size,
|
||||
num_units * dir_count};
|
||||
std::vector<int> seq_lengths_shape = {batch_size};
|
||||
@ -119,9 +121,9 @@ TEST(CudnnRNNOpsTest, ForwardV3Lstm_ShapeFn) {
|
||||
};
|
||||
string input_shapes_desc = strings::StrCat(
|
||||
shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";",
|
||||
shape_to_str(input_h_shape), ";", "[?]", ";",
|
||||
shape_to_str(input_c_shape), ";", "[?]", ";",
|
||||
shape_to_str(seq_lengths_shape));
|
||||
string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?;?";
|
||||
string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in2;?;?";
|
||||
|
||||
ShapeInferenceTestOp op("CudnnRNNV3");
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNNV3")
|
||||
|
@ -98,6 +98,7 @@ def _cudnn_rnn_backwardv3(op, *grads):
|
||||
seed=op.get_attr("seed"),
|
||||
seed2=op.get_attr("seed2"),
|
||||
time_major=op.get_attr("time_major"),
|
||||
num_proj=op.get_attr("num_proj"),
|
||||
rnn_mode=op.get_attr("rnn_mode"),
|
||||
input_mode=op.get_attr("input_mode"),
|
||||
direction=op.get_attr("direction")) + (None,)
|
||||
|
@ -1002,8 +1002,8 @@ class CudnnRnnParamsDescriptor {
|
||||
class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
|
||||
PersistentRnnPlan rnn_plan, int num_layers,
|
||||
int hidden_size, int input_size, int batch_size,
|
||||
cudnnRNNInputMode_t input_mode,
|
||||
int hidden_size, int input_size, int cell_size,
|
||||
int batch_size, cudnnRNNInputMode_t input_mode,
|
||||
cudnnDirectionMode_t direction_mode,
|
||||
cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
|
||||
cudnnDataType_t compute_type,
|
||||
@ -1015,6 +1015,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
num_layers_(num_layers),
|
||||
hidden_size_(hidden_size),
|
||||
input_size_(input_size),
|
||||
cell_size_(cell_size),
|
||||
batch_size_(batch_size),
|
||||
rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
|
||||
input_mode_(input_mode),
|
||||
@ -1031,7 +1032,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
|
||||
static port::StatusOr<CudnnRnnDescriptor> Create(
|
||||
const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
|
||||
int batch_size, cudnnRNNInputMode_t input_mode,
|
||||
int cell_size, int batch_size, cudnnRNNInputMode_t input_mode,
|
||||
cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
|
||||
cudnnDataType_t data_type, cudnnDataType_t compute_type,
|
||||
const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
|
||||
@ -1044,12 +1045,28 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
|
||||
|
||||
// TODO: allow the user to choose an algorithm.
|
||||
int unified_size = hidden_size;
|
||||
bool use_projection = cell_size != 0 && hidden_size < cell_size;
|
||||
if (use_projection) {
|
||||
unified_size = cell_size;
|
||||
}
|
||||
RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
|
||||
cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size,
|
||||
/*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(),
|
||||
/*inputMode=*/input_mode, /*direction=*/direction_mode,
|
||||
/*mode=*/rnn_mode, /*algo=*/rnn_algo,
|
||||
cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
|
||||
/*hiddenSize=*/unified_size, /*numLayers=*/num_layers,
|
||||
/*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
|
||||
/*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
|
||||
/*dataType=*/compute_type));
|
||||
if (use_projection) {
|
||||
#if CUDNN_VERSION >= 7101
|
||||
RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
|
||||
cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
|
||||
/*recProjSize=*/hidden_size, /*outProjSize=*/0));
|
||||
#else
|
||||
return port::Status(port::error::INVALID_ARGUMENT,
|
||||
"No supported cudnnSetRNNProjectionLayers when "
|
||||
"CUDNN_VERSION < 7.1.1");
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
|
||||
// But in the future if these APIs are used to process full length arrays,
|
||||
@ -1106,9 +1123,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
#endif // CUDNN_VERSION >= 7000
|
||||
|
||||
return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
|
||||
num_layers, hidden_size, input_size, batch_size,
|
||||
input_mode, direction_mode, rnn_mode, data_type,
|
||||
compute_type, algorithm_config,
|
||||
num_layers, hidden_size, input_size, cell_size,
|
||||
batch_size, input_mode, direction_mode, rnn_mode,
|
||||
data_type, compute_type, algorithm_config,
|
||||
std::move(dropout_desc), std::move(params_desc));
|
||||
}
|
||||
|
||||
@ -1116,6 +1133,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
int num_layers() const { return num_layers_; }
|
||||
int hidden_size() const { return hidden_size_; }
|
||||
int input_size() const { return input_size_; }
|
||||
int cell_size() const { return cell_size_; }
|
||||
int batch_size() const { return batch_size_; }
|
||||
cudnnRNNInputMode_t input_mode() const { return input_mode_; }
|
||||
cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
|
||||
@ -1144,6 +1162,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
int num_layers_;
|
||||
int hidden_size_;
|
||||
int input_size_;
|
||||
// cell_size_ is the size of cell state, which will be different from
|
||||
// hidden_size_ if the projection is used.
|
||||
int cell_size_;
|
||||
// batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
|
||||
// algorithm.
|
||||
int batch_size_;
|
||||
@ -1161,6 +1182,69 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
||||
|
||||
namespace {
|
||||
|
||||
// Check if the LSTM projection is used. If yes, an additional weigth matrix
|
||||
// (projection matrix) will be fetched to the 'weights'. Otherwise, nothing will
|
||||
// be done.
|
||||
port::Status CheckAndFetchProjectionWeights(
|
||||
const CudnnHandle& cudnn, cudnnRNNDescriptor_t rnn_desc, const int layer,
|
||||
const TensorDescriptor& input_desc, const FilterDescriptor& filter_desc,
|
||||
const FilterDescriptor& region_desc_handle,
|
||||
dnn::RnnDescriptor::ParamsRegions* weights) {
|
||||
#if CUDNN_VERSION >= 7101
|
||||
int hidden_size_v;
|
||||
int num_layers_v;
|
||||
cudnnDropoutDescriptor_t dropout_desc;
|
||||
cudnnRNNInputMode_t input_mode;
|
||||
cudnnDirectionMode_t direction;
|
||||
cudnnRNNMode_t mode;
|
||||
cudnnRNNAlgo_t algo;
|
||||
cudnnDataType_t data_type;
|
||||
RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
|
||||
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
|
||||
/*hiddenSize=*/&hidden_size_v,
|
||||
/*numLayers=*/&num_layers_v,
|
||||
/*dropoutDesc=*/&dropout_desc,
|
||||
/*inputMode=*/&input_mode,
|
||||
/*direction=*/&direction,
|
||||
/*mode=*/&mode,
|
||||
/*algo=*/&algo,
|
||||
/*dataType=*/&data_type));
|
||||
int rec_proj_size_v;
|
||||
int out_proj_size_v;
|
||||
RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
|
||||
/*handle=*/cudnn.handle(),
|
||||
/*rnnDesc=*/rnn_desc,
|
||||
/*recProjSize*/ &rec_proj_size_v,
|
||||
/*outProjSize*/ &out_proj_size_v));
|
||||
if (rec_proj_size_v != hidden_size_v) {
|
||||
void* offset = nullptr;
|
||||
int region_id = 8;
|
||||
RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
|
||||
/*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
|
||||
/*layer=*/layer, /*xDesc=*/input_desc.get(),
|
||||
/*wDesc=*/filter_desc.get(),
|
||||
/*w=*/nullptr, /*linLayerID=*/region_id,
|
||||
/*linLayerMatDesc=*/region_desc_handle.get(),
|
||||
/*linLayerMat or linLayerBias=*/&offset));
|
||||
int dims[] = {1, 1, 1};
|
||||
cudnnDataType_t data_type;
|
||||
cudnnTensorFormat_t tensor_format;
|
||||
int n_dims;
|
||||
RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
|
||||
/*filterDesc=*/region_desc_handle.get(),
|
||||
/*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
|
||||
/*dataType=*/&data_type, /*format=*/&tensor_format,
|
||||
/*nbDims=*/&n_dims, /*filterDimA=*/dims));
|
||||
int64 size =
|
||||
dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
|
||||
dnn::RnnDescriptor::ParamsRegion region = {reinterpret_cast<int64>(offset),
|
||||
size};
|
||||
weights->push_back(region);
|
||||
}
|
||||
#endif // CUDNN_VERSION >= 7101
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
|
||||
const CudnnHandle& cudnn, int input_size, cudnnDataType_t data_type,
|
||||
cudnnRNNDescriptor_t rnn_desc, cudnnRNNMode_t rnn_mode,
|
||||
@ -1248,6 +1332,9 @@ port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
|
||||
(type == 0 ? weights : biases).push_back(region);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(CheckAndFetchProjectionWeights(
|
||||
cudnn, rnn_desc, layer, input_desc, filter_desc, region_desc_handle,
|
||||
&weights));
|
||||
}
|
||||
|
||||
return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
|
||||
@ -1412,6 +1499,7 @@ struct RnnModelDims {
|
||||
int max_seq_length = 0;
|
||||
int hidden_size = 0;
|
||||
int input_size = 0;
|
||||
int cell_size = 0;
|
||||
int dir_count = 0;
|
||||
};
|
||||
|
||||
@ -1437,6 +1525,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
|
||||
model_dims.max_seq_length = input_desc.max_seq_length();
|
||||
model_dims.hidden_size = rnn_desc.hidden_size();
|
||||
model_dims.input_size = input_desc.data_size();
|
||||
model_dims.cell_size = rnn_desc.cell_size();
|
||||
model_dims.dir_count =
|
||||
(rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
|
||||
|
||||
@ -1447,9 +1536,11 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
|
||||
input_h_desc.data_size() == model_dims.hidden_size)) {
|
||||
return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_h shape");
|
||||
}
|
||||
// The LSTM projection will be used if input_h_desc.data_size() <
|
||||
// input_c_desc.data_size()
|
||||
if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
|
||||
input_h_desc.batch_size() == input_c_desc.batch_size() &&
|
||||
input_h_desc.data_size() == input_c_desc.data_size())) {
|
||||
input_h_desc.data_size() <= input_c_desc.data_size())) {
|
||||
return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
|
||||
}
|
||||
if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
|
||||
@ -1466,7 +1557,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
|
||||
}
|
||||
if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
|
||||
input_h_desc.batch_size() == output_c_desc.batch_size() &&
|
||||
input_h_desc.data_size() == output_c_desc.data_size())) {
|
||||
input_h_desc.data_size() <= output_c_desc.data_size())) {
|
||||
return port::Status(port::error::INVALID_ARGUMENT,
|
||||
"Invalid output_c shape");
|
||||
}
|
||||
@ -1872,18 +1963,18 @@ port::Status CudnnSupport::DoRnnBackwardImpl(
|
||||
|
||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
||||
CudnnSupport::createRnnDescriptor(
|
||||
int num_layers, int hidden_size, int input_size, int batch_size,
|
||||
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
|
||||
dnn::RnnMode rnn_mode, dnn::DataType data_type,
|
||||
const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
|
||||
ScratchAllocator* state_allocator) {
|
||||
int num_layers, int hidden_size, int input_size, int cell_size,
|
||||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
|
||||
float dropout, uint64 seed, ScratchAllocator* state_allocator) {
|
||||
// Setting up a cudnnRNNDescriptor requires a cuDNN handle, but because it's
|
||||
// not enqueueing anything into a stream, we pass in the null stream.
|
||||
auto cudnn = cudnn_->GetHandle(parent_, /*stream=*/nullptr);
|
||||
SE_ASSIGN_OR_RETURN(
|
||||
CudnnRnnDescriptor rnn_desc,
|
||||
CudnnRnnDescriptor::Create(
|
||||
cudnn, num_layers, hidden_size, input_size, batch_size,
|
||||
cudnn, num_layers, hidden_size, input_size, cell_size, batch_size,
|
||||
ToCudnnRnnInputMode(input_mode),
|
||||
ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
|
||||
ToCudnnDataType(data_type), GetRnnComputeType(data_type),
|
||||
|
@ -47,11 +47,11 @@ class CudnnSupport : public dnn::DnnSupport {
|
||||
port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;
|
||||
|
||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
|
||||
int num_layers, int hidden_size, int input_size, int batch_size,
|
||||
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
|
||||
dnn::RnnMode rnn_mode, dnn::DataType data_type,
|
||||
const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
|
||||
ScratchAllocator* state_allocator) override;
|
||||
int num_layers, int hidden_size, int input_size, int cell_size,
|
||||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
|
||||
float dropout, uint64 seed, ScratchAllocator* state_allocator) override;
|
||||
|
||||
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
|
||||
createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,
|
||||
|
@ -2065,6 +2065,7 @@ class DnnSupport {
|
||||
// num_layers: the number of layers for a RNN model.
|
||||
// hidden_size: the size of the hidden state.
|
||||
// input_size: the size of the input state.
|
||||
// cell_size: the size of the cell state
|
||||
// input_mode: an enum to specify whether a linear transformation is added
|
||||
// after the input state. If input_size is different from hidden_size, this
|
||||
// is required.
|
||||
@ -2080,7 +2081,8 @@ class DnnSupport {
|
||||
// is no longer in use.
|
||||
virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
||||
createRnnDescriptor(int num_layers, int hidden_size, int input_size,
|
||||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
int cell_size, int batch_size,
|
||||
dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode,
|
||||
dnn::RnnMode rnn_mode, dnn::DataType data_type,
|
||||
const dnn::AlgorithmConfig& algorithm_config,
|
||||
|
@ -336,18 +336,18 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
|
||||
|
||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
|
||||
StreamExecutor::createRnnDescriptor(
|
||||
int num_layers, int hidden_size, int input_size, int batch_size,
|
||||
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
|
||||
dnn::RnnMode rnn_mode, dnn::DataType data_type,
|
||||
const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
|
||||
ScratchAllocator *state_allocator) {
|
||||
int num_layers, int hidden_size, int input_size, int cell_size,
|
||||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
|
||||
float dropout, uint64 seed, ScratchAllocator *state_allocator) {
|
||||
dnn::DnnSupport *dnn_support = AsDnn();
|
||||
if (!dnn_support) {
|
||||
return port::Status(port::error::UNKNOWN,
|
||||
"Fail to find the dnn implementation.");
|
||||
}
|
||||
return dnn_support->createRnnDescriptor(
|
||||
num_layers, hidden_size, input_size, batch_size, input_mode,
|
||||
num_layers, hidden_size, input_size, cell_size, batch_size, input_mode,
|
||||
direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
|
||||
state_allocator);
|
||||
}
|
||||
|
@ -394,11 +394,11 @@ class StreamExecutor {
|
||||
// Create an RNN descriptor based on model shapes and configurations.
|
||||
// The caller retains the ownership of the descriptor.
|
||||
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
|
||||
int num_layers, int hidden_size, int input_size, int batch_size,
|
||||
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
|
||||
dnn::RnnMode rnn_mode, dnn::DataType data_type,
|
||||
const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
|
||||
ScratchAllocator *state_allocator);
|
||||
int num_layers, int hidden_size, int input_size, int cell_size,
|
||||
int batch_size, dnn::RnnInputMode input_mode,
|
||||
dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
|
||||
dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
|
||||
float dropout, uint64 seed, ScratchAllocator *state_allocator);
|
||||
|
||||
// Create a RNN sequence descriptor that specifies either the input or output
|
||||
// sequence. The caller retains the ownership of the returned descriptor.
|
||||
|
@ -770,27 +770,35 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNBackpropV3"
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNCanonicalToParams"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNCanonicalToParamsV2"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNParamsSize"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNParamsToCanonical"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNParamsToCanonicalV2"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNV2"
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNV3"
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Cumprod"
|
||||
|
@ -770,27 +770,35 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNBackpropV3"
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNCanonicalToParams"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNCanonicalToParamsV2"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNParamsSize"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNParamsToCanonical"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNParamsToCanonicalV2"
|
||||
argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNV2"
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "CudnnRNNV3"
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
|
||||
argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "Cumprod"
|
||||
|
Loading…
Reference in New Issue
Block a user