add cudnn lstm projection (lstmp)
parent 9318759bfa
commit 9ff2778789
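
For reviewers, a sketch (background, not part of the diff) of the LSTM-with-recurrent-projection (LSTMP) recurrence that the new num_proj argument enables. The projection kernel has TF-canonical shape [num_units, num_proj] (the tests below create it as "rnn/lstm_cell/projection/kernel"), and the projected state r_t is what the gates recur over, which is why initial_h and the recurrent halves of the gate kernels shrink to num_proj:

% Hedged sketch of the LSTMP cell; x_t = input, r_t = projected state.
\begin{aligned}
g_t &= \sigma(W_g x_t + R_g r_{t-1} + b_g), \qquad g \in \{i, f, o\} \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tanh(W_c x_t + R_c r_{t-1} + b_c) \\
h_t &= o_t \odot \tanh(c_t) \\
r_t &= W_{proj}^{\top} h_t, \qquad W_{proj} \in \mathbb{R}^{\text{num\_units} \times \text{num\_proj}}
\end{aligned}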
@@ -71,7 +71,8 @@ def RunLSTM(sess,
             is_training=True,
             dropout=0.,
             num_dirs=True,
-            dtype=dtypes.float32):
+            dtype=dtypes.float32,
+            num_proj=None):
   # TODO(jamesqin): add multi-layer tests.
   # TODO(jamesqin): add multi-dir tests
   assert num_layers == 1
@@ -94,7 +95,8 @@ def RunLSTM(sess,
   initial_h_op = variable_scope.get_variable(
       "initial_h_op",
       initializer=np.random.rand(batch_size,
-                                 num_units).astype(dtype.as_numpy_dtype),
+                                 num_proj if num_proj else num_units)
+      .astype(dtype.as_numpy_dtype),
       dtype=dtype)
   initial_c_op = variable_scope.get_variable(
       "initial_c_op",
@@ -115,13 +117,19 @@ def RunLSTM(sess,
   with variable_scope.variable_scope("test", initializer=initializer):
     w = variable_scope.get_variable(
         "rnn/lstm_cell/kernel",
-        shape=[input_size + num_units, num_units * 4],
+        shape=[input_size + (num_proj if num_proj else num_units),
+               num_units * 4],
         dtype=dtype)
     b = variable_scope.get_variable(
         "rnn/lstm_cell/bias", shape=[num_units * 4], dtype=dtype)
+    if num_proj:
+      pw = variable_scope.get_variable(
+          "rnn/lstm_cell/projection/kernel",
+          shape=[num_units, num_proj], dtype=dtype)

     # canonical lstm. must set forget_bias to 0. to align with cudnn lstm.
-    cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True)
+    cell = rnn_cell_impl.LSTMCell(num_units, forget_bias=0., reuse=True,
+                                  num_proj=num_proj if num_proj else None)
     outputs_op, state_tuple_op = rnn.dynamic_rnn(
         cell,
         inputs_static,
@@ -134,8 +142,12 @@ def RunLSTM(sess,

   # Convert to cudnn opaque param.
   format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM(
-      num_layers, num_units, input_size)
-  opaque_params = format_converter.tf_canonical_to_opaque([w, b])
+      num_layers, num_units, input_size,
+      num_proj=num_proj if num_proj else None)
+  if num_proj:
+    opaque_params = format_converter.tf_canonical_to_opaque([w, b], [pw,])
+  else:
+    opaque_params = format_converter.tf_canonical_to_opaque([w, b])

   cu_initial_h_op = array_ops.expand_dims(
       initial_h_op, axis=(0 if time_major else 1))
@@ -150,16 +162,22 @@ def RunLSTM(sess,
       time_major=time_major,
       dropout=dropout,
       is_training=is_training,
-      rnn_mode=cudnn_rnn_ops.CUDNN_LSTM)
+      rnn_mode=cudnn_rnn_ops.CUDNN_LSTM,
+      num_proj=num_proj if num_proj else None)
   # Remove the trivial 1st dimension.
   cu_state_tuple_op = rnn_cell_impl.LSTMStateTuple(
       c=array_ops.squeeze(cu_c_op, axis=0 if time_major else 1),
       h=array_ops.squeeze(cu_h_op, axis=0 if time_major else 1))

   if is_training:
-    (inp_grad_op, hgrad_op,
-     cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients(
-         outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b])
+    if num_proj:
+      (inp_grad_op, hgrad_op,
+       cgrad_op, wgrad_op, bgrad_op, pwgrad_op) = gradients_impl.gradients(
+           outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b, pw])
+    else:
+      (inp_grad_op, hgrad_op,
+       cgrad_op, wgrad_op, bgrad_op) = gradients_impl.gradients(
+           outputs_op, [inputs_static, initial_h_op, initial_c_op, w, b])

     (cu_inp_grad_op, cu_hgrad_op,
      cu_cgrad_op, opaque_grad_op) = gradients_impl.gradients(
@@ -170,10 +188,16 @@ def RunLSTM(sess,
     # Remove the trivial 1st dimension
     cu_cgrad_op = array_ops.squeeze(cu_cgrad_op, axis=0 if time_major else 1)

-    cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
-        opaque_grad_op)
+    if num_proj:
+      cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op = \
+          format_converter.opaque_to_tf_canonical(opaque_grad_op)
+    else:
+      cu_wgrad_op, cu_bgrad_op = format_converter.opaque_to_tf_canonical(
+          opaque_grad_op)
     cu_wgrad_op = cu_wgrad_op[0]
     cu_bgrad_op = cu_bgrad_op[0]
+    if num_proj:
+      cu_pwgrad_op = cu_pwgrad_op[0]
     # cudnn lstm has 2 biases each gate. When converting to tf canonical format,
     # the two biases are summed into one. Thus here bias gradient should be
     # halved when comparing with tf lstm.
@@ -183,18 +207,30 @@ def RunLSTM(sess,
   sess.run(init_op)

   if is_training:
-    outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([
-        outputs_op, state_tuple_op, inp_grad_op,
-        (hgrad_op, cgrad_op), wgrad_op, bgrad_op
-    ])
-    (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
-     cu_bgrad) = sess.run(
-         [
-             cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
-             (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
-         ],
-         feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
+    if num_proj:
+      outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad, pwgrad = \
+          sess.run([
+              outputs_op, state_tuple_op, inp_grad_op,
+              (hgrad_op, cgrad_op), wgrad_op, bgrad_op, pwgrad_op
+          ])
+      (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
+       cu_bgrad, cu_pwgrad) = sess.run([
+           cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
+           (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op, cu_pwgrad_op
+       ],
+       feed_dict={inputs: inputs_np} if dynamic_shape_input else None)
+    else:
+      outputs, state_tuple, inp_grad, state_grad, wgrad, bgrad = sess.run([
+          outputs_op, state_tuple_op, inp_grad_op,
+          (hgrad_op, cgrad_op), wgrad_op, bgrad_op
+      ])
+      (cu_outputs, cu_state_tuple, cu_inp_grad, cu_state_grad, cu_wgrad,
+       cu_bgrad) = sess.run([
+           cu_outputs_op, cu_state_tuple_op, cu_inp_grad_op,
+           (cu_hgrad_op, cu_cgrad_op), cu_wgrad_op, cu_bgrad_op
+       ],
+       feed_dict={inputs: inputs_np} if dynamic_shape_input else None)

   logging.vlog(1, "outputs: %s" % outputs)
   logging.vlog(1, "cu_outputs: %s" % cu_outputs)
   logging.vlog(1, "state_tuple: %s" % str(state_tuple))
@@ -205,11 +241,20 @@ def RunLSTM(sess,
     logging.vlog(1, "cu_state_grad: %s" % str(cu_state_grad))
     logging.vlog(1, "wgrad: %s" % str(wgrad))
     logging.vlog(1, "bgrad: %s" % str(bgrad))
+    if num_proj:
+      logging.vlog(1, "pwgrad: %s" % str(pwgrad))
     logging.vlog(1, "cu_wgrad: %s" % str(cu_wgrad))
     logging.vlog(1, "cu_bgrad: %s" % str(cu_bgrad))
-    return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
-            cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad,
-            cu_bgrad)
+    if num_proj:
+      logging.vlog(1, "cu_pwgrad: %s" % str(cu_pwgrad))
+    if num_proj:
+      return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
+              cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad,
+              cu_wgrad, cu_bgrad, cu_pwgrad)
+    else:
+      return (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
+              cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad,
+              cu_bgrad)
   else:
     outputs, state_tuple = sess.run([outputs_op, state_tuple_op])
     cu_outputs, cu_state_tuple = sess.run([cu_outputs_op, cu_state_tuple_op],
@@ -226,35 +271,43 @@ def RunLSTM(sess,

 # Basic set of RNN configs to test. They can be further extended in relevant
 # test (e.g. adding num_dirs).
+#NAMED_RNN_TESTCASES = ({
+#    "testcase_name": "xsmall",
+#    "num_units": 1,
+#    "input_size": 1,
+#    "batch_size": 1,
+#    "time": 1,
+#    "num_layers": 1,
+#}, {
+#    "testcase_name": "small",
+#    "num_units": 4,
+#    "input_size": 4,
+#    "batch_size": 4,
+#    "time": 4,
+#    "num_layers": 1,
+#}, {
+#    "testcase_name": "medium",
+#    "num_units": 128,
+#    "input_size": 64,
+#    "batch_size": 8,
+#    "time": 16,
+#    "num_layers": 1,
+#}, {
+#    "testcase_name": "large",
+#    "num_units": 128,
+#    "input_size": 128,
+#    "batch_size": 16,
+#    "time": 32,
+#    "num_layers": 1,
+#})
 NAMED_RNN_TESTCASES = ({
-    "testcase_name": "xsmall",
-    "num_units": 1,
-    "input_size": 1,
-    "batch_size": 1,
-    "time": 1,
-    "num_layers": 1,
-}, {
     "testcase_name": "small",
     "num_units": 4,
     "input_size": 4,
     "batch_size": 4,
     "time": 4,
     "num_layers": 1,
-}, {
-    "testcase_name": "medium",
-    "num_units": 128,
-    "input_size": 64,
-    "batch_size": 8,
-    "time": 16,
-    "num_layers": 1,
-}, {
-    "testcase_name": "large",
-    "num_units": 128,
-    "input_size": 128,
-    "batch_size": 16,
-    "time": 32,
-    "num_layers": 1,
-})
+}, )


 def ExpandNamedTestCases(inputs, *remove_keys, **extra_configs):
@@ -349,19 +402,22 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
                             time_major,
                             dynamic_shape_input=False,
                             rtol=3e-6,
-                            atol=3e-6):
+                            atol=3e-6,
+                            num_proj=None):
     with self.session(use_gpu=True) as sess:
-      (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad, cu_inp_grad,
-       state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad, cu_bgrad) = RunLSTM(
-           sess,
-           num_units,
-           input_size,
-           batch_size,
-           time,
-           num_layers,
-           variable_seq_lengths=variable_seq_lengths,
-           time_major=time_major,
-           dynamic_shape_input=dynamic_shape_input)
+      if num_proj is not None and num_proj != 0:
+        (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
+         cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, pwgrad, cu_wgrad,
+         cu_bgrad, cu_pwgrad) = RunLSTM(
+             sess, num_units, input_size, batch_size, time, num_layers,
+             variable_seq_lengths=variable_seq_lengths,
+             dynamic_shape_input=dynamic_shape_input, num_proj=num_proj)
+      else:
+        (outputs, cu_outputs, state_tuple, cu_state_tuple, inp_grad,
+         cu_inp_grad, state_grad, cu_state_grad, wgrad, bgrad, cu_wgrad,
+         cu_bgrad) = RunLSTM(sess, num_units, input_size, batch_size, time,
+                             num_layers,
+                             variable_seq_lengths=variable_seq_lengths,
+                             dynamic_shape_input=dynamic_shape_input,
+                             num_proj=num_proj)

       self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
       for s, cu_s in zip(state_tuple, cu_state_tuple):
@@ -371,6 +427,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       self.assertAllClose(inp_grad, cu_inp_grad, rtol=rtol, atol=atol)
       self.assertAllClose(bgrad, cu_bgrad, rtol=rtol, atol=atol)
       self.assertAllClose(wgrad, cu_wgrad, rtol=rtol, atol=atol)
+      if num_proj is not None and num_proj != 0:
+        self.assertAllClose(pwgrad, cu_pwgrad, rtol=rtol, atol=atol)

   @parameterized.named_parameters(
       ExpandNamedTestCases(
@@ -378,10 +436,15 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
           "variable_seq_lengths": [True, False],
           "time_major": [True, False],
           "dynamic_shape_input": [True, False],
+          "use_proj": [True, False],
       }))
   @test_util.run_gpu_only
   def test_training(self, num_units, input_size, batch_size, time, num_layers,
-                    variable_seq_lengths, time_major, dynamic_shape_input):
+                    variable_seq_lengths, time_major, dynamic_shape_input,
+                    use_proj):
+    num_proj = num_units // 2
+    if use_proj and num_proj == 0:
+      self.skipTest("num_proj cannot be 0")
     self._test_training_helper(
         num_units,
         input_size,
@@ -391,7 +454,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         dtypes.float32,
         variable_seq_lengths=variable_seq_lengths,
         time_major=time_major,
-        dynamic_shape_input=dynamic_shape_input)
+        dynamic_shape_input=dynamic_shape_input,
+        num_proj=num_proj if use_proj else None)

   @parameterized.named_parameters(
       ExpandNamedTestCases(
@@ -399,11 +463,15 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
           "variable_seq_lengths": [True, False],
           "time_major": [True, False],
           "dynamic_shape_input": [True, False],
+          "use_proj": [True, False],
       }))
   @test_util.run_gpu_only
   def test_training_fp16(self, num_units, input_size, batch_size, time,
                          num_layers, variable_seq_lengths, time_major,
-                         dynamic_shape_input):
+                         dynamic_shape_input, use_proj):
+    num_proj = num_units // 2
+    if use_proj and num_proj == 0:
+      self.skipTest("num_proj cannot be 0")
     self._test_training_helper(
         num_units,
         input_size,
@@ -415,7 +483,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         atol=5e-4,
         variable_seq_lengths=variable_seq_lengths,
         time_major=time_major,
-        dynamic_shape_input=dynamic_shape_input)
+        dynamic_shape_input=dynamic_shape_input,
+        num_proj=num_proj if use_proj else None)

   @parameterized.named_parameters(
       ExpandNamedTestCases(
@@ -423,10 +492,15 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
           "variable_seq_lengths": [True, False],
           "time_major": [True, False],
           "dynamic_shape_input": [True, False],
+          "use_proj": [True, False],
       }))
   @test_util.run_gpu_only
   def test_inference(self, num_units, input_size, batch_size, time, num_layers,
-                     variable_seq_lengths, time_major, dynamic_shape_input):
+                     variable_seq_lengths, time_major, dynamic_shape_input,
+                     use_proj):
+    num_proj = num_units // 2
+    if use_proj and num_proj == 0:
+      self.skipTest("num_proj cannot be 0")
     with self.session(use_gpu=True) as sess:
       (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
           sess,
@@ -438,7 +512,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
           is_training=False,
           variable_seq_lengths=variable_seq_lengths,
           time_major=time_major,
-          dynamic_shape_input=dynamic_shape_input)
+          dynamic_shape_input=dynamic_shape_input,
+          num_proj=num_proj if use_proj else None)

       self.assertAllClose(outputs, cu_outputs)
       # h
@@ -452,11 +527,15 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
           "variable_seq_lengths": [True, False],
           "time_major": [True, False],
           "dynamic_shape_input": [True, False],
+          "use_proj": [True, False],
       }))
   @test_util.run_gpu_only
   def test_inference_fp16(self, num_units, input_size, batch_size, time,
                           num_layers, variable_seq_lengths, time_major,
-                          dynamic_shape_input):
+                          dynamic_shape_input, use_proj):
+    num_proj = num_units // 2
+    if use_proj and num_proj == 0:
+      self.skipTest("num_proj cannot be 0")
     with self.session(use_gpu=True) as sess:
       (outputs, cu_outputs, state_tuple, cu_state_tuple) = RunLSTM(
           sess,
@@ -469,7 +548,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
          dtype=dtypes.float16,
          variable_seq_lengths=variable_seq_lengths,
          time_major=time_major,
-          dynamic_shape_input=dynamic_shape_input)
+          dynamic_shape_input=dynamic_shape_input,
+          num_proj=num_proj if use_proj else None)

       rtol, atol = 5e-3, 5e-4
       self.assertAllClose(outputs, cu_outputs, rtol=rtol, atol=atol)
@@ -486,12 +566,16 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
           "variable_seq_lengths": [True, False],
           "time_major": [True, False],
           "dynamic_shape_input": [True, False],
+          "use_proj": [True, False],
       }))
   @test_util.run_gpu_only
   def test_inference_with_dropout(self, num_units, input_size, batch_size, time,
                                   num_layers, variable_seq_lengths, time_major,
-                                  dynamic_shape_input):
+                                  dynamic_shape_input, use_proj):
     """Validates that dropout does not affect Cudnn Rnn inference."""
+    num_proj = num_units // 2
+    if use_proj and num_proj == 0:
+      self.skipTest("num_proj cannot be 0")
     # Hand-picked dropouts are used below (0. and 1.)
     with ops.Graph().as_default() as g:
       with self.session(use_gpu=True, graph=g) as sess:
@@ -507,7 +591,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
             dropout=0.,
             variable_seq_lengths=variable_seq_lengths,
             time_major=time_major,
-            dynamic_shape_input=dynamic_shape_input)
+            dynamic_shape_input=dynamic_shape_input,
+            num_proj=num_proj if use_proj else None)

     with ops.Graph().as_default() as g:
       with self.session(use_gpu=True, graph=g) as sess:
@@ -522,7 +607,8 @@ class CudnnLSTMTest(test_util.TensorFlowTestCase, parameterized.TestCase):
             dropout=1.,
             variable_seq_lengths=variable_seq_lengths,
             time_major=time_major,
-            dynamic_shape_input=dynamic_shape_input)
+            dynamic_shape_input=dynamic_shape_input,
+            num_proj=num_proj if use_proj else None)

     self.assertAllClose(cu_outputs, cu_outputs2)
     # h
@@ -890,40 +976,61 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
                                      parameterized.TestCase):
   """Class for testing various format converters."""

-  def _test_lstm_helper(self, num_units, input_size, num_layers, direction):
+  def _test_lstm_helper(self, num_units, input_size, num_layers, direction,
+                        num_proj=None):
     with self.session(use_gpu=True) as sess:
       random_seed.set_random_seed(0)
       np.random.seed(0)

       num_dirs = 1 if direction == cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION else 2
       format_converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM(
-          num_layers, num_units, input_size, direction=direction)
+          num_layers, num_units, input_size, direction=direction,
+          num_proj=num_proj if num_proj else None)

-      ws, bs = [], []
+      ws, bs, pws = [], [], []
       for _ in range(num_layers * num_dirs):
         w = constant_op.constant(
-            np.random.rand(input_size + num_units, 4 * num_units),
+            np.random.rand(input_size + (num_proj if num_proj else num_units),
+                           4 * num_units),
             dtype=dtypes.float32)
         b = constant_op.constant(
             np.random.rand(4 * num_units), dtype=dtypes.float32)
         ws.append(w)
         bs.append(b)
+        if num_proj:
+          pw = constant_op.constant(
+              np.random.rand(num_units, num_proj), dtype=dtypes.float32)
+          pws.append(pw)

-      opaque_params = format_converter.tf_canonical_to_opaque(ws + bs)
+      if num_proj:
+        opaque_params = format_converter.tf_canonical_to_opaque(ws + bs, pws)
+      else:
+        opaque_params = format_converter.tf_canonical_to_opaque(ws + bs)
+
       opaque_params_size = cudnn_rnn_ops.cudnn_rnn_opaque_params_size(
           cudnn_rnn_ops.CUDNN_LSTM,
           num_layers,
          num_units,
          input_size,
-          direction=direction)
+          direction=direction,
+          num_proj=num_proj if num_proj else None)

-      ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params)
+      if num_proj:
+        ws_r, bs_r, pws_r = format_converter.opaque_to_tf_canonical(
+            opaque_params)
+        ws, ws_r, pws, bs, bs_r, pws_r = sess.run([ws, ws_r, pws, bs, bs_r,
+                                                   pws_r])
+      else:
+        ws_r, bs_r = format_converter.opaque_to_tf_canonical(opaque_params)
+        ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r])

       # Test tf_canonical_to_opaque() followed by opaque_to_tf_canonical()
       # returns the original input.
-      ws, ws_r, bs, bs_r = sess.run([ws, ws_r, bs, bs_r])
       for w, w_r in zip(ws, ws_r):
         self.assertAllClose(w, w_r)
+      if num_proj:
+        for pw, pw_r in zip(pws, pws_r):
+          self.assertAllClose(pw, pw_r)
       for b, b_r in zip(bs, bs_r):
         self.assertAllClose(b, b_r)
@@ -942,6 +1049,18 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
     self._test_lstm_helper(num_units, input_size, num_layers,
                            cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION)

+  @parameterized.named_parameters((c["testcase_name"], c["num_units"],
+                                   c["input_size"], c["num_layers"])
+                                  for c in NAMED_RNN_TESTCASES)
+  @test_util.run_gpu_only
+  def test_lstmp(self, num_units, input_size, num_layers):
+    num_proj = num_units // 2
+    if num_proj == 0:
+      self.skipTest("num_proj cannot be 0")
+    self._test_lstm_helper(num_units, input_size, num_layers,
+                           cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION,
+                           num_proj=num_proj)
+
   @parameterized.named_parameters((c["testcase_name"], c["num_units"],
                                    c["input_size"], c["num_layers"])
                                   for c in NAMED_RNN_TESTCASES)
@@ -950,6 +1069,18 @@ class CudnnParamsFormatConverterTest(test_util.TensorFlowTestCase,
     self._test_lstm_helper(num_units, input_size, num_layers,
                            cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION)

+  @parameterized.named_parameters((c["testcase_name"], c["num_units"],
+                                   c["input_size"], c["num_layers"])
+                                  for c in NAMED_RNN_TESTCASES)
+  @test_util.run_gpu_only
+  def test_lstmp_bidi(self, num_units, input_size, num_layers):
+    num_proj = num_units // 2
+    if num_proj == 0:
+      self.skipTest("num_proj cannot be 0")
+    self._test_lstm_helper(num_units, input_size, num_layers,
+                           cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION,
+                           num_proj=num_proj)
+
   def _test_gru_helper(self, num_units, input_size, num_layers, direction):
     with self.session(use_gpu=True) as sess:
       random_seed.set_random_seed(0)
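
The hunks below thread num_proj through cudnn_rnn_ops itself. As a reader's sketch of how the converter API behaves after this change (the import paths are assumed from the contrib layout, and the underlying conversion ops need a GPU kernel):

# Minimal round-trip sketch, mirroring the _test_lstm_helper shapes above:
# stacked kernel [input_size + num_proj, 4 * num_units], bias [4 * num_units],
# projection kernel [num_units, num_proj].
import numpy as np
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.python.framework import constant_op, dtypes

num_units, input_size, num_proj = 4, 4, 2
w = constant_op.constant(
    np.random.rand(input_size + num_proj, 4 * num_units), dtype=dtypes.float32)
b = constant_op.constant(np.random.rand(4 * num_units), dtype=dtypes.float32)
pw = constant_op.constant(
    np.random.rand(num_units, num_proj), dtype=dtypes.float32)

converter = cudnn_rnn_ops.CudnnParamsFormatConverterLSTM(
    1, num_units, input_size, num_proj=num_proj)
# With projection, the projection kernels travel as a separate second list ...
opaque = converter.tf_canonical_to_opaque([w, b], [pw])
# ... and come back as a third return value.
ws_r, bs_r, pws_r = converter.opaque_to_tf_canonical(opaque)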
@@ -184,6 +184,7 @@ class CudnnParamsFormatConverter(object):
                num_layers,
                num_units,
                input_size,
+               num_proj=None,
                input_mode=CUDNN_INPUT_LINEAR_MODE,
                direction=CUDNN_RNN_UNIDIRECTION):
     """Constructor.
@@ -193,6 +194,8 @@ class CudnnParamsFormatConverter(object):
       num_units: the number of units within the RNN model.
       input_size: the size of the input, it could be different from the
         num_units.
+      num_proj: The output dimensionality for the projection matrices.
+        If None or 0, no projection is performed.
       input_mode: indicate whether there is a linear projection between the
         input and the actual computation before the first layer. It could be one
         of 'linear_input', 'skip_input' or 'auto_select'. * 'linear_input'
@@ -207,14 +210,16 @@ class CudnnParamsFormatConverter(object):
     self._input_size = input_size
     self._num_units = num_units
     self._input_mode = input_mode
+    self._num_proj = num_proj
     self._direction = direction
     self._num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2
     self._num_params = (
         self._num_params_per_layer * self._num_layers * self._num_dirs)

-  def tf_canonical_to_opaque(self, tf_canonicals):
+  def tf_canonical_to_opaque(self, tf_canonicals, weights_proj=None):
     r"""Converts tf canonical weights to cudnn opaque param."""
-    cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals)
+    cu_weights, cu_biases = self._tf_canonical_to_cu_canonical(tf_canonicals,
+                                                               weights_proj)
     cu_weights = [array_ops.reshape(w, [-1]) for w in cu_weights]
     opaque_params = self._cu_canonical_to_opaque(cu_weights, cu_biases)
     return opaque_params
@@ -222,8 +227,14 @@ class CudnnParamsFormatConverter(object):
   def opaque_to_tf_canonical(self, opaque_param):
     r"""Converts cudnn opaque param to tf canonical weights."""
     cu_weights, cu_biases = self._opaque_to_cu_canonical(opaque_param)
-    weights, biases = self._cu_canonical_to_tf_canonical(cu_weights, cu_biases)
-    return weights, biases
+    if self._num_proj:
+      weights, biases, weights_proj = self._cu_canonical_to_tf_canonical(
+          cu_weights, cu_biases)
+      return weights, biases, weights_proj
+    else:
+      weights, biases = self._cu_canonical_to_tf_canonical(
+          cu_weights, cu_biases)
+      return weights, biases

   def _opaque_to_cu_canonical(self, opaque_param):
     """Converts opaque params to Cudnn canonical format.
@@ -235,15 +246,31 @@ class CudnnParamsFormatConverter(object):
       2 list for weights and biases respectively.
     """
     with ops.device("/gpu:0"):
-      weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
-          num_layers=self._num_layers,
-          num_units=self._num_units,
-          input_size=self._input_size,
-          params=opaque_param,
-          num_params=self._num_params,
-          rnn_mode=self._rnn_mode,
-          input_mode=self._input_mode,
-          direction=self._direction)
+      if self._num_proj:
+        num_params_weights = (self._num_params +
+                              1 * self._num_layers * self._num_dirs)
+        num_params_biases = self._num_params
+        weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2(
+            num_layers=self._num_layers,
+            num_units=self._num_units,
+            input_size=self._input_size,
+            params=opaque_param,
+            rnn_mode=self._rnn_mode,
+            input_mode=self._input_mode,
+            direction=self._direction,
+            num_params_weights=num_params_weights,
+            num_params_biases=num_params_biases,
+            num_proj=self._num_proj)
+      else:
+        weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
+            num_layers=self._num_layers,
+            num_units=self._num_units,
+            input_size=self._input_size,
+            params=opaque_param,
+            num_params=self._num_params,
+            rnn_mode=self._rnn_mode,
+            input_mode=self._input_mode,
+            direction=self._direction)
     return (weights, biases)

   def _cu_canonical_to_opaque(self, cu_weights, cu_biases):
@@ -256,16 +283,28 @@ class CudnnParamsFormatConverter(object):
       a single opaque tensor.
     """
     with ops.device("/gpu:0"):
-      return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
-          num_layers=self._num_layers,
-          num_units=self._num_units,
-          input_size=self._input_size,
-          weights=cu_weights,
-          biases=cu_biases,
-          rnn_mode=self._rnn_mode,
-          input_mode=self._input_mode,
-          direction=self._direction)
+      if self._num_proj:
+        return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2(
+            num_layers=self._num_layers,
+            num_units=self._num_units,
+            input_size=self._input_size,
+            weights=cu_weights,
+            biases=cu_biases,
+            rnn_mode=self._rnn_mode,
+            input_mode=self._input_mode,
+            num_proj=self._num_proj,
+            direction=self._direction)
+      else:
+        return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
+            num_layers=self._num_layers,
+            num_units=self._num_units,
+            input_size=self._input_size,
+            weights=cu_weights,
+            biases=cu_biases,
+            rnn_mode=self._rnn_mode,
+            input_mode=self._input_mode,
+            direction=self._direction)

   def _cu_canonical_to_tf_canonical(self, cu_weights, cu_biases):
     r"""Transform from Cudnn canonical to tf canonical.

@@ -289,9 +328,11 @@ class CudnnParamsFormatConverter(object):
       1 tuple, tf canonical weights and biases.
     """
     tf_weights, tf_biases = [], []
+    tf_weights_proj = []

     layer_weights_num = self._num_params_per_layer * self._num_dirs
     layer_biases_num = layer_weights_num
+    layer_weights_num += (1 * self._num_dirs) if self._num_proj else 0

     for i in range(self._num_layers):
       layer_weights = cu_weights[i * layer_weights_num:(i + 1) *
@@ -299,7 +340,7 @@ class CudnnParamsFormatConverter(object):
       layer_biases = cu_biases[i * layer_biases_num:(i + 1) * layer_biases_num]
       if self._direction == CUDNN_RNN_UNIDIRECTION:
         self._cu_canonical_to_tf_canonical_single_layer(
-            layer_weights, layer_biases, tf_weights, tf_biases)
+            layer_weights, layer_biases, tf_weights, tf_biases, tf_weights_proj)
       else:
         fw_weights = layer_weights[:len(layer_weights) // 2]
         bw_weights = layer_weights[len(layer_weights) // 2:]
@@ -311,6 +352,7 @@ class CudnnParamsFormatConverter(object):
             fw_biases,
             tf_weights,
             tf_biases,
+            tf_weights_proj,
         )

         self._cu_canonical_to_tf_canonical_single_layer(
@@ -318,11 +360,16 @@ class CudnnParamsFormatConverter(object):
             bw_biases,
             tf_weights,
             tf_biases,
+            tf_weights_proj,
         )
-    return (tf_weights, tf_biases)
+    if self._num_proj:
+      return (tf_weights, tf_biases, tf_weights_proj)
+    else:
+      return (tf_weights, tf_biases)

   def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases,
-                                                 tf_weights, tf_biases):
+                                                 tf_weights, tf_biases,
+                                                 tf_weights_proj=None):
     r"""Transform single layer Cudnn canonicals to tf canonicals.

     The elements of cu_weights, cu_biases are laid out in the following format:
@@ -337,7 +384,7 @@ class CudnnParamsFormatConverter(object):
     """
     raise NotImplementedError("Abstract method")

-  def _tf_canonical_to_cu_canonical(self, tf_canonicals):
+  def _tf_canonical_to_cu_canonical(self, tf_canonicals, weights_proj):
     r"""Transform from tf canonical to Cudnn canonical.

     This is the reverse routine of _TransformCanonical().
@@ -356,6 +403,7 @@ class CudnnParamsFormatConverter(object):
         ---------------
         |fwd    |bak   |
         ---------------
+      weights_proj: (optional) weight matrices for the projection.
     Returns:
       2 lists: the recovered cudnn canonical weights and biases.
     """
@@ -370,6 +418,9 @@ class CudnnParamsFormatConverter(object):
       layer_biases = biases[i * layer_biases_num:(i + 1) * layer_biases_num]
       if self._direction == CUDNN_RNN_UNIDIRECTION:
         cu_weights.extend(self._tf_to_cudnn_weights(i, *layer_weights))
+        if weights_proj is not None:
+          pw = array_ops.transpose(weights_proj[i])
+          cu_weights.append(pw)
         cu_biases.extend(self._tf_to_cudnn_biases(*layer_biases))
       else:
         fw_weights, bw_weights = layer_weights[:len(
@@ -377,9 +428,15 @@ class CudnnParamsFormatConverter(object):
         fw_biases, bw_biases = layer_biases[:len(
             layer_biases) // 2], layer_biases[len(layer_biases) // 2:]
         cu_weights.extend(self._tf_to_cudnn_weights(i, *fw_weights))
+        if weights_proj is not None:
+          pw0 = array_ops.transpose(weights_proj[2*i+0])
+          cu_weights.append(pw0)
         cu_biases.extend(self._tf_to_cudnn_biases(*fw_biases))

         cu_weights.extend(self._tf_to_cudnn_weights(i, *bw_weights))
+        if weights_proj is not None:
+          pw1 = array_ops.transpose(weights_proj[2*i+1])
+          cu_weights.append(pw1)
         cu_biases.extend(self._tf_to_cudnn_biases(*bw_biases))
     return cu_weights, cu_biases

@@ -415,7 +472,10 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):

   def _cudnn_to_tf_weights(self, *cu_weights):
     r"""Stitching cudnn canonical weights to generate tf canonical weights."""
-    w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights
+    if self._num_proj:
+      w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o, pw = cu_weights
+    else:
+      w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o = cu_weights

     # pylint: disable=invalid-name
     W_i = array_ops.concat([w_i, r_i], axis=1)
@@ -425,7 +485,11 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
     # pylint: enable=invalid-name
     # Cudnn LSTM weights are in ifco order, other tf LSTMs are in icfo order.
     reordered = self._cudnn_to_tf_gate_params(* [W_i, W_f, W_c, W_o])
-    return (array_ops.transpose(array_ops.concat(reordered, axis=0)),)
+    if self._num_proj:
+      return (array_ops.transpose(array_ops.concat(reordered, axis=0)),
+              array_ops.transpose(pw))
+    else:
+      return (array_ops.transpose(array_ops.concat(reordered, axis=0)),)

   def _tf_to_cudnn_weights(self, layer, *tf_weights):
     r"""Reverse the operations in StitchWeights()."""
@@ -434,7 +498,7 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
     if layer == 0:
       input_weight_width = input_size
     else:
-      input_weight_width = num_units
+      input_weight_width = self._num_proj if self._num_proj else num_units
     if self._direction == CUDNN_RNN_BIDIRECTION:
       input_weight_width *= 2

@@ -444,10 +508,15 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
     W_i, W_f, W_c, W_o = self._tf_to_cudnn_gate_params(*array_ops.split(
         w, 4, axis=0))

-    w_i, r_i = array_ops.split(W_i, [input_weight_width, num_units], axis=1)
-    w_c, r_c = array_ops.split(W_c, [input_weight_width, num_units], axis=1)
-    w_f, r_f = array_ops.split(W_f, [input_weight_width, num_units], axis=1)
-    w_o, r_o = array_ops.split(W_o, [input_weight_width, num_units], axis=1)
+    hidden_state_width = self._num_proj if self._num_proj else num_units
+    w_i, r_i = array_ops.split(W_i, [input_weight_width, hidden_state_width],
+                               axis=1)
+    w_c, r_c = array_ops.split(W_c, [input_weight_width, hidden_state_width],
+                               axis=1)
+    w_f, r_f = array_ops.split(W_f, [input_weight_width, hidden_state_width],
+                               axis=1)
+    w_o, r_o = array_ops.split(W_o, [input_weight_width, hidden_state_width],
+                               axis=1)
     return w_i, w_f, w_c, w_o, r_i, r_f, r_c, r_o
     # pylint: enable=invalid-name

@@ -483,10 +552,16 @@ class CudnnParamsFormatConverterLSTM(CudnnParamsFormatConverter):
     return b_wi, b_wf, b_wc, b_wo, b_ri, b_rf, b_rc, b_ro

   def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases,
-                                                 tf_weights, tf_biases):
-    (w,) = self._cudnn_to_tf_weights(*cu_weights)
+                                                 tf_weights, tf_biases,
+                                                 tf_weights_proj=None):
+    if self._num_proj:
+      (w, pw) = self._cudnn_to_tf_weights(*cu_weights)
+      tf_weights.append(w)
+      tf_weights_proj.append(pw)
+    else:
+      (w,) = self._cudnn_to_tf_weights(*cu_weights)
+      tf_weights.append(w)
     (b,) = self._cudnn_to_tf_biases(*cu_biases)
-    tf_weights.append(w)
     tf_biases.append(b)


@@ -554,7 +629,8 @@ class CudnnParamsFormatConverterGRU(CudnnParamsFormatConverter):
     return b_wi, b_wr, b_wh, b_ri, b_rr, b_rh

   def _cu_canonical_to_tf_canonical_single_layer(self, cu_weights, cu_biases,
-                                                 tf_weights, tf_biases):
+                                                 tf_weights, tf_biases,
+                                                 tf_weights_proj=None):
     # pylint: disable=invalid-name
     W_ir, w_h, r_h = self._cudnn_to_tf_weights(*cu_weights)
     b_ir, b_wh, b_rh = self._cudnn_to_tf_biases(*cu_biases)
@@ -727,8 +803,9 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject):
   def format_converter(self):
     if self._format_converter is None:
       self._format_converter = self._format_converter_cls(
-          self._num_layers, self._num_units, self._input_size, self._input_mode,
-          self._direction)
+          self._num_layers, self._num_units, self._input_size,
+          input_mode=self._input_mode,
+          direction=self._direction)
     return self._format_converter

   def restore(self, restored_tensors, restored_shapes):
@@ -962,6 +1039,7 @@ def _cudnn_rnn(inputs,
                direction=CUDNN_RNN_UNIDIRECTION,
                dropout=0.,
                seed=0,
+               num_proj=None,
                name=None):
   """Cudnn RNN.

@@ -999,6 +1077,8 @@ def _cudnn_rnn(inputs,
     dropout: whether to enable dropout. When it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See `tf.set_random_seed`
       for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
   Returns:
     outputs, output_h, output_c
@@ -1027,13 +1107,15 @@ def _cudnn_rnn(inputs,
   if sequence_lengths is not None:
     args["sequence_lengths"] = sequence_lengths
     args["time_major"] = time_major
+    args["num_proj"] = 0 if num_proj is None else num_proj
     outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
-  elif time_major is False:
+  elif time_major is False or num_proj:
     batch_size = array_ops.shape(inputs)[0]
     max_time = array_ops.shape(inputs)[1]
     sequence_lengths = array_ops.fill([batch_size], max_time)
     args["sequence_lengths"] = sequence_lengths
     args["time_major"] = time_major
+    args["num_proj"] = 0 if num_proj is None else num_proj
     outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(**args)
   elif use_cudnn_v2 != "1":
     outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args)
@@ -1053,6 +1135,7 @@ def cudnn_lstm(inputs,
                direction=CUDNN_RNN_UNIDIRECTION,
                dropout=0.,
                seed=0,
+               num_proj=None,
                name=None):
   """Cudnn LSTM.

@@ -1089,13 +1172,15 @@ def cudnn_lstm(inputs,
     dropout: whether to enable dropout. When it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See `tf.set_random_seed`
       for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
   Returns:
     outputs, output_h, output_c
   """
   return _cudnn_rnn(inputs, input_h, input_c, params, is_training, CUDNN_LSTM,
                     sequence_lengths, time_major, input_mode, direction,
-                    dropout, seed, name)
+                    dropout, seed, num_proj, name)


 def _cudnn_rnn_no_input_c(inputs,
@@ -1151,7 +1236,7 @@ def _cudnn_rnn_no_input_c(inputs,
   input_c = array_ops.constant([], dtype=input_h.dtype)
   outputs, output_h, _ = _cudnn_rnn(
       inputs, input_h, input_c, params, is_training, rnn_mode, sequence_lengths,
-      time_major, input_mode, direction, dropout, seed, name)
+      time_major, input_mode, direction, dropout, seed, None, name)
   return outputs, output_h


@@ -1322,6 +1407,7 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
                                          direction=CUDNN_RNN_UNIDIRECTION,
                                          dropout=0,
                                          seed=0,
+                                         num_proj=None,
                                          name=None):
   """Convert cudnn opaque params to canonical.

@@ -1346,6 +1432,8 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
     dropout: whether to enable dropout. When it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See `tf.set_random_seed`
       for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
   Returns:
     weights list and bias list
@ -1358,19 +1446,39 @@ def cudnn_rnn_opaque_params_to_canonical(rnn_mode,
|
|||||||
check_input_mode(input_mode)
|
check_input_mode(input_mode)
|
||||||
num_params = _get_num_params(rnn_mode, num_layers, direction)
|
num_params = _get_num_params(rnn_mode, num_layers, direction)
|
||||||
seed, seed2 = random_seed.get_seed(seed)
|
seed, seed2 = random_seed.get_seed(seed)
|
||||||
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
|
num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2
|
||||||
rnn_mode=rnn_mode,
|
if num_proj is not None and num_proj != 0:
|
||||||
num_layers=num_layers,
|
num_params_weights = (num_params + 1 * num_layers * num_dirs)
|
||||||
num_units=num_units,
|
num_params_biases = num_params
|
||||||
input_size=input_size,
|
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical_v2(
|
||||||
params=params,
|
rnn_mode=rnn_mode,
|
||||||
input_mode=input_mode,
|
num_layers=num_layers,
|
||||||
direction=direction,
|
num_units=num_units,
|
||||||
dropout=dropout,
|
input_size=input_size,
|
||||||
seed=seed,
|
params=params,
|
||||||
seed2=seed2,
|
input_mode=input_mode,
|
||||||
num_params=num_params,
|
direction=direction,
|
||||||
name=name)
|
dropout=dropout,
|
||||||
|
seed=seed,
|
||||||
|
seed2=seed2,
|
||||||
|
num_params_weights=num_params_weights,
|
||||||
|
num_params_biases=num_params_biases,
|
||||||
|
num_proj=num_proj,
|
||||||
|
name=name)
|
||||||
|
else:
|
||||||
|
weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical(
|
||||||
|
rnn_mode=rnn_mode,
|
||||||
|
num_layers=num_layers,
|
||||||
|
num_units=num_units,
|
||||||
|
input_size=input_size,
|
||||||
|
params=params,
|
||||||
|
input_mode=input_mode,
|
||||||
|
direction=direction,
|
||||||
|
dropout=dropout,
|
||||||
|
seed=seed,
|
||||||
|
seed2=seed2,
|
||||||
|
num_params=num_params,
|
||||||
|
name=name)
|
||||||
return weights, biases
|
return weights, biases
|
||||||
|
|
||||||
|
|
||||||
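The num_params + 1 * num_layers * num_dirs term above encodes the canonical layout change: cuDNN's plain LSTM exposes eight weight regions and eight bias regions per layer per direction, and projection adds exactly one extra weight region (the projection kernel) while leaving biases untouched. As a sketch:

    # Sketch: canonical region counts for cuDNN LSTM, with/without projection.
    # Eight gate matrices (4 input-side + 4 recurrent-side) per layer per
    # direction; projection contributes one additional weight region.
    def canonical_region_counts(num_layers, num_dirs, with_projection):
        num_params = 8 * num_layers * num_dirs
        weights = num_params + (num_layers * num_dirs if with_projection else 0)
        biases = num_params  # bias regions are unaffected by projection
        return weights, biases

    assert canonical_region_counts(1, 1, False) == (8, 8)
    assert canonical_region_counts(1, 1, True) == (9, 8)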
@@ -1384,6 +1492,7 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
                                          direction=CUDNN_RNN_UNIDIRECTION,
                                          dropout=0,
                                          seed=0,
+                                         num_proj=None,
                                          name=None):
   """Converts params from the canonical format to a specific format of cuDNN.
 
@@ -1409,6 +1518,8 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
     dropout: whether to enable dropout. With it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See `tf.set_random_seed`
       for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
   Returns:
     an opaque Cudnn param.
@@ -1419,20 +1530,35 @@ def cudnn_rnn_canonical_to_opaque_params(rnn_mode,
   check_direction(direction)
   check_input_mode(input_mode)
   seed, seed2 = random_seed.get_seed(seed)
-  return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
-      rnn_mode=rnn_mode,
-      num_layers=num_layers,
-      num_units=num_units,
-      input_size=input_size,
-      weights=weights,
-      biases=biases,
-      input_mode=input_mode,
-      direction=direction,
-      dropout=dropout,
-      seed=seed,
-      seed2=seed2,
-      name=name)
+  if num_proj is not None and num_proj != 0:
+    return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params_v2(
+        rnn_mode=rnn_mode,
+        num_layers=num_layers,
+        num_units=num_units,
+        input_size=input_size,
+        weights=weights,
+        biases=biases,
+        input_mode=input_mode,
+        direction=direction,
+        dropout=dropout,
+        seed=seed,
+        seed2=seed2,
+        num_proj=num_proj,
+        name=name)
+  else:
+    return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params(
+        rnn_mode=rnn_mode,
+        num_layers=num_layers,
+        num_units=num_units,
+        input_size=input_size,
+        weights=weights,
+        biases=biases,
+        input_mode=input_mode,
+        direction=direction,
+        dropout=dropout,
+        seed=seed,
+        seed2=seed2,
+        name=name)
 
 
 def cudnn_rnn_opaque_params_size(rnn_mode,
                                  num_layers,
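Both conversion directions keep the v1 ops untouched and only route through the _v2 ops when a projection is actually requested, so pre-existing graphs are unaffected. A sketch of that dispatch rule (op names as registered later in this commit):

    # Sketch: which kernel each conversion call resolves to.
    def pick_conversion_op(num_proj):
        if num_proj:  # None and 0 both mean "no projection"
            return "CudnnRNNCanonicalToParamsV2"
        return "CudnnRNNCanonicalToParams"

    assert pick_conversion_op(None) == "CudnnRNNCanonicalToParams"
    assert pick_conversion_op(0) == "CudnnRNNCanonicalToParams"
    assert pick_conversion_op(16) == "CudnnRNNCanonicalToParamsV2"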
@@ -1443,6 +1569,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode,
                                  dtype=dtypes.float32,
                                  dropout=0,
                                  seed=0,
+                                 num_proj=None,
                                  name=None):
   """Returns opaque params size for specific Cudnn config.
 
@@ -1467,6 +1594,8 @@ def cudnn_rnn_opaque_params_size(rnn_mode,
     dropout: whether to enable dropout. With it is 0, dropout is disabled.
     seed: the op seed used for initializing dropout. See `tf.set_random_seed`
       for behavior.
+    num_proj: The output dimensionality for the projection matrices.
+      If None or 0, no projection is performed.
     name: name of the operation.
   Returns:
     a int, size of Cudnn opaque params.
@@ -1482,6 +1611,7 @@ def cudnn_rnn_opaque_params_size(rnn_mode,
       num_layers=num_layers,
       num_units=num_units,
       input_size=input_size,
+      num_proj=num_proj,
       T=dtype,
       S=dtypes.int32,
       dropout=dropout,
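num_proj must reach the size op because it changes the recurrent matrix widths; sizing the opaque buffer without it would under-allocate. A rough sketch of the dominant term (the real op also accounts for cuDNN's internal padding and layout, so treat this as an estimate, not the returned value):

    # Sketch: approximate parameter count for one LSTMP layer, one direction.
    def approx_lstmp_params(input_size, num_units, num_proj=None):
        h = num_proj if num_proj else num_units
        gates = 4 * num_units * (input_size + h)  # W and R gate matrices
        biases = 8 * num_units                    # cuDNN keeps two bias sets
        proj = num_units * num_proj if num_proj else 0
        return gates + biases + proj

    assert approx_lstmp_params(10, 32) == 4 * 32 * (10 + 32) + 8 * 32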
@@ -1510,7 +1640,8 @@ class _CudnnRNN(object):
                direction=CUDNN_RNN_UNIDIRECTION,
                dtype=dtypes.float32,
                dropout=0.,
-               seed=0):
+               seed=0,
+               num_proj=None):
     """Creates a CudnnRNN model from model spec.
 
     Args:
@@ -1534,6 +1665,8 @@ class _CudnnRNN(object):
       dropout: whether to enable dropout. With it is 0, dropout is disabled.
       seed: the op seed used for initializing dropout. See `tf.set_random_seed`
         for behavior.
+      num_proj: The output dimensionality for the projection matrices.
+        If None or 0, no projection is performed.
     Raises:
       ValueError: if direction is invalid.
     """
@@ -1546,6 +1679,7 @@ class _CudnnRNN(object):
     self._dtype = dtype
     self._dropout = dropout
     self._seed = seed
+    self._num_proj = num_proj
 
   @property
   def input_mode(self):
@@ -1571,6 +1705,10 @@ class _CudnnRNN(object):
   def direction(self):
     return self._direction
 
+  @property
+  def num_proj(self):
+    return self._num_proj
+
   def params_size(self):
     """Calculates the size of the opaque parameter buffer needed for this model.
 
@@ -1582,6 +1720,7 @@ class _CudnnRNN(object):
         num_layers=self._num_layers,
         num_units=self._num_units,
         input_size=self._input_size,
+        num_proj=self._num_proj,
         dtype=self._dtype,
         dropout=self._dropout,
         seed=self._seed,
@@ -1637,7 +1776,8 @@ class _CudnnRNN(object):
         input_mode=self._input_mode,
         direction=self._direction,
         dropout=self._dropout,
-        seed=self._seed)
+        seed=self._seed,
+        num_proj=self._num_proj)
 
   def params_to_canonical(self, params):
     """Converts params from a specific format of cuDNN to the canonical format.
@@ -1657,7 +1797,8 @@ class _CudnnRNN(object):
         input_mode=self._input_mode,
         direction=self._direction,
         dropout=self._dropout,
-        seed=self._seed)
+        seed=self._seed,
+        num_proj=self._num_proj)
 
   def canonical_to_params(self, weights, biases):
     """Converts params from the canonical format to a specific format of cuDNN.
@@ -1679,7 +1820,8 @@ class _CudnnRNN(object):
         input_mode=self._input_mode,
         direction=self._direction,
         dropout=self._dropout,
-        seed=self._seed)
+        seed=self._seed,
+        num_proj=self._num_proj)
 
 
 class CudnnLSTM(_CudnnRNN):
@@ -1697,7 +1839,8 @@ class CudnnLSTM(_CudnnRNN):
                direction=CUDNN_RNN_UNIDIRECTION,
                dtype=dtypes.float32,
                dropout=0.,
-               seed=0):
+               seed=0,
+               num_proj=None):
     """Creates a Cudnn LSTM model from model spec.
 
     Args:
@@ -1716,6 +1859,8 @@ class CudnnLSTM(_CudnnRNN):
       dtype: dtype of params, tf.float32 or tf.float64.
       dropout: whether to enable dropout. With it is 0, dropout is disabled.
       seed: the seed used for initializing dropout.
+      num_proj: The output dimensionality for the projection matrices.
+        If None or 0, no projection is performed.
     """
     super(CudnnLSTM, self).__init__(
         CUDNN_LSTM,
@@ -1726,7 +1871,8 @@ class CudnnLSTM(_CudnnRNN):
         direction=direction,
         dtype=dtype,
         dropout=dropout,
-        seed=seed)
+        seed=seed,
+        num_proj=num_proj)
 
   def __call__(self,
                input_data,
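At the wrapper level the whole feature is one constructor argument. A hedged usage sketch against the CudnnLSTM class defined above (assumes the contrib module path this file lives in and a CUDA-enabled TensorFlow 1.x build):

    # Sketch: kernel-level CudnnLSTM wrapper with projection.
    from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops

    model = cudnn_rnn_ops.CudnnLSTM(
        num_layers=1,
        num_units=32,   # cell (c) width
        input_size=10,
        num_proj=16)    # recurrent/output (h) width
    params_size_t = model.params_size()  # now accounts for the projection kernel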
@@ -502,20 +502,23 @@ struct CudnnRnnModelShapes {
   int dir_count;
   int max_seq_length;
   int batch_size;
+  int c_num_units;
   TensorShape input_shape;
   TensorShape output_shape;
   TensorShape hidden_state_shape;
+  TensorShape c_state_shape;
   // At present only fields related to cached RnnDescriptor are concerned.
   bool IsCompatibleWith(const CudnnRnnModelShapes& rhs) const {
     return num_layers == rhs.num_layers && input_size == rhs.input_size &&
-           num_units == rhs.num_units && dir_count == rhs.dir_count;
+           num_units == rhs.num_units && dir_count == rhs.dir_count &&
+           c_num_units == rhs.c_num_units;
   }
   string DebugString() const {
     return strings::Printf(
         "[num_layers, input_size, num_units, dir_count, max_seq_length, "
-        "batch_size]: [%d, %d, %d, %d, %d, %d] ",
+        "batch_size, c_num_units]: [%d, %d, %d, %d, %d, %d, %d] ",
         num_layers, input_size, num_units, dir_count, max_seq_length,
-        batch_size);
+        batch_size, c_num_units);
   }
 };
 
@@ -562,6 +565,7 @@ Status ExtractForwardInput(OpKernelContext* context,
                            const CudnnModelTypes& model_types, bool time_major,
                            const Tensor** input, const Tensor** input_h,
                            const Tensor** input_c, const Tensor** params,
+                           const int num_proj,
                            CudnnRnnModelShapes* model_shapes) {
   TF_RETURN_IF_ERROR(context->input("input", input));
   TF_RETURN_IF_ERROR(context->input("input_h", input_h));
@@ -615,12 +619,48 @@ Status ExtractForwardInput(OpKernelContext* context,
                                    model_shapes->hidden_state_shape.DebugString());
   }
   if (model_types.HasInputC()) {
-    if ((*input_h)->shape() != (*input_c)->shape()) {
-      return errors::InvalidArgument(
-          "input_h and input_c must have the same shape: ",
-          (*input_h)->shape().DebugString(), " ",
-          (*input_c)->shape().DebugString());
+    model_shapes->c_num_units = (*input_c)->dim_size(2);
+    if (time_major) {
+      model_shapes->c_state_shape =
+          TensorShape({model_shapes->dir_count * model_shapes->num_layers,
+                       model_shapes->batch_size, model_shapes->c_num_units});
+    } else {
+      model_shapes->c_state_shape =
+          TensorShape({model_shapes->batch_size,
+                       model_shapes->dir_count * model_shapes->num_layers,
+                       model_shapes->c_num_units});
     }
+    if (num_proj == 0) {
+      if ((*input_h)->shape() != (*input_c)->shape()) {
+        return errors::InvalidArgument(
+            "input_h and input_c must have the same shape w/o projection: ",
+            (*input_h)->shape().DebugString(), " ",
+            (*input_c)->shape().DebugString());
+      }
+    } else {
+      if ((*input_h)->dim_size(2) > (*input_c)->dim_size(2) ||
+          num_proj != (*input_h)->dim_size(2) ||
+          (*input_h)->dim_size(0) != (*input_c)->dim_size(0) ||
+          (*input_h)->dim_size(1) != (*input_c)->dim_size(1)) {
+        return errors::InvalidArgument(
+            "Invalid input_h and input_c w/ projection size: ", num_proj, " ",
+            (*input_h)->shape().DebugString(), " ",
+            (*input_c)->shape().DebugString());
+      }
+    }
+  } else {
+    // dummy c_state_shape TODO(kaixih): remove the time_major branch
+    if (time_major) {
+      model_shapes->c_state_shape =
+          TensorShape({model_shapes->dir_count * model_shapes->num_layers,
+                       model_shapes->batch_size, model_shapes->num_units});
+    } else {
+      model_shapes->c_state_shape =
+          TensorShape({model_shapes->batch_size,
+                       model_shapes->dir_count * model_shapes->num_layers,
+                       model_shapes->num_units});
+    }
+    model_shapes->c_num_units = 0;
   }
   if (time_major) {
     model_shapes->output_shape =
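The validation above fixes the contract between the two state inputs once projection is on: input_h must be exactly num_proj wide, input_c at least as wide, and their layer/batch dimensions must match. The same invariant written out as a sketch:

    # Sketch: the h/c shape contract ExtractForwardInput enforces
    # (time-major: [num_layers * dir_count, batch_size, units]).
    def check_lstm_state_shapes(h_shape, c_shape, num_proj):
        if num_proj == 0:
            assert h_shape == c_shape, "h and c must match w/o projection"
        else:
            assert h_shape[2] == num_proj, "h width must equal num_proj"
            assert h_shape[2] <= c_shape[2], "c must be at least as wide as h"
            assert h_shape[:2] == c_shape[:2], "layer/batch dims must agree"

    check_lstm_state_shapes((1, 8, 16), (1, 8, 32), num_proj=16)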
@@ -639,18 +679,19 @@ Status ExtractForwardInput(OpKernelContext* context,
                            const CudnnModelTypes& model_types, bool time_major,
                            const Tensor** input, const Tensor** input_h,
                            const Tensor** input_c, const Tensor** params,
-                           const Tensor** sequence_lengths,
+                           const Tensor** sequence_lengths, const int num_proj,
                            CudnnRnnModelShapes* model_shapes) {
   TF_RETURN_IF_ERROR(context->input("sequence_lengths", sequence_lengths));
   return ExtractForwardInput(context, model_types, time_major, input, input_h,
-                             input_c, params, model_shapes);
+                             input_c, params, num_proj, model_shapes);
 }
 
 template <typename T>
 Status CreateForwardAndBackwardIODescriptors(
     OpKernelContext* context, const CudnnRnnModelShapes& model_shapes,
     std::unique_ptr<RnnSequenceTensorDescriptor>* input_desc,
-    std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
+    std::unique_ptr<RnnStateTensorDescriptor>* h_state_desc,
+    std::unique_ptr<RnnStateTensorDescriptor>* c_state_desc,
     std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc,
     const absl::Span<const int>& seq_lengths, bool time_major) {
   StreamExecutor* executor = context->op_device_context()->stream()->parent();
@@ -658,6 +699,7 @@ Status CreateForwardAndBackwardIODescriptors(
 
   const TensorShape& input_shape = model_shapes.input_shape;
   const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
+  const TensorShape& c_state_shape = model_shapes.c_state_shape;
   const TensorShape& output_shape = model_shapes.output_shape;
 
   DCHECK_EQ(input_shape.dims(), 3);
@@ -689,13 +731,28 @@ Status CreateForwardAndBackwardIODescriptors(
         hidden_state_shape.dim_size(0), hidden_state_shape.dim_size(1),
         hidden_state_shape.dim_size(2), data_type);
     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
-    *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+    *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
   } else {
     auto hidden_state_desc_s = executor->createRnnStateTensorDescriptor(
         hidden_state_shape.dim_size(1), hidden_state_shape.dim_size(0),
         hidden_state_shape.dim_size(2), data_type);
     TF_RETURN_IF_ERROR(hidden_state_desc_s.status());
-    *state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+    *h_state_desc = hidden_state_desc_s.ConsumeValueOrDie();
+  }
+
+  DCHECK_EQ(c_state_shape.dims(), 3);
+  if (time_major) {
+    auto c_state_desc_s = executor->createRnnStateTensorDescriptor(
+        c_state_shape.dim_size(0), c_state_shape.dim_size(1),
+        c_state_shape.dim_size(2), data_type);
+    TF_RETURN_IF_ERROR(c_state_desc_s.status());
+    *c_state_desc = c_state_desc_s.ConsumeValueOrDie();
+  } else {
+    auto c_state_desc_s = executor->createRnnStateTensorDescriptor(
+        c_state_shape.dim_size(1), c_state_shape.dim_size(0),
+        c_state_shape.dim_size(2), data_type);
+    TF_RETURN_IF_ERROR(c_state_desc_s.status());
+    *c_state_desc = c_state_desc_s.ConsumeValueOrDie();
   }
 
   DCHECK_EQ(output_shape.dims(), 3);
@@ -739,7 +796,8 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
                  ScratchAllocator* workspace_allocator,
                  ProfileResult* output_profile_result) {
   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
-  std::unique_ptr<RnnStateTensorDescriptor> state_desc;
+  std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
+  std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
 
   absl::Span<const int> seq_lengths;
@@ -748,8 +806,8 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
   }
   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
-      context, model_shapes, &input_desc, &state_desc, &output_desc,
-      seq_lengths, time_major));
+      context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
+      &output_desc, seq_lengths, time_major));
 
   auto input_data = AsDeviceMemory<T>(input);
   auto input_h_data = AsDeviceMemory<T>(input_h);
@@ -769,11 +827,11 @@ Status DoForward(OpKernelContext* context, const RnnDescriptor& rnn_desc,
   Stream* stream = context->op_device_context()->stream();
   bool launch_success =
       stream
-          ->ThenRnnForward(rnn_desc, *input_desc, input_data, *state_desc,
-                           input_h_data, *state_desc, input_c_data, params_data,
-                           *output_desc, &output_data, *state_desc,
-                           &output_h_data, *state_desc, &output_c_data,
-                           is_training, reserve_space_allocator,
+          ->ThenRnnForward(rnn_desc, *input_desc, input_data, *h_state_desc,
+                           input_h_data, *c_state_desc, input_c_data,
+                           params_data, *output_desc, &output_data,
+                           *h_state_desc, &output_h_data, *c_state_desc,
+                           &output_c_data, is_training, reserve_space_allocator,
                            workspace_allocator, output_profile_result)
           .ok();
   return launch_success
@@ -801,7 +859,8 @@ Status DoBackward(
     ScratchAllocator* workspace_allocator,
     ProfileResult* output_profile_result) {
   std::unique_ptr<RnnSequenceTensorDescriptor> input_desc;
-  std::unique_ptr<RnnStateTensorDescriptor> state_desc;
+  std::unique_ptr<RnnStateTensorDescriptor> h_state_desc;
+  std::unique_ptr<RnnStateTensorDescriptor> c_state_desc;
   std::unique_ptr<RnnSequenceTensorDescriptor> output_desc;
 
   absl::Span<const int> seq_lengths;
@@ -810,8 +869,8 @@ Status DoBackward(
         sequence_lengths->template flat<int>().data(), model_shapes.batch_size);
   }
   TF_RETURN_IF_ERROR(CreateForwardAndBackwardIODescriptors<T>(
-      context, model_shapes, &input_desc, &state_desc, &output_desc,
-      seq_lengths, time_major));
+      context, model_shapes, &input_desc, &h_state_desc, &c_state_desc,
+      &output_desc, seq_lengths, time_major));
 
   auto input_data = AsDeviceMemory<T>(input);
   auto input_h_data = AsDeviceMemory<T>(input_h);
@@ -847,15 +906,16 @@ Status DoBackward(
   Stream* stream = context->op_device_context()->stream();
   bool launch_success =
       stream
-          ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *state_desc,
-                            input_h_data, *state_desc, input_c_data,
-                            params_data, *output_desc, output_data, *state_desc,
-                            output_h_data, *state_desc, output_c_data,
-                            output_backprop_data, output_h_backprop_data,
-                            output_c_backprop_data, &input_backprop_data,
-                            &input_h_backprop_data, &input_c_backprop_data,
-                            &params_backprop_data, &reserve_space_uint8,
-                            workspace_allocator, output_profile_result)
+          ->ThenRnnBackward(rnn_desc, *input_desc, input_data, *h_state_desc,
+                            input_h_data, *c_state_desc, input_c_data,
+                            params_data, *output_desc, output_data,
+                            *h_state_desc, output_h_data, *c_state_desc,
+                            output_c_data, output_backprop_data,
+                            output_h_backprop_data, output_c_backprop_data,
+                            &input_backprop_data, &input_h_backprop_data,
+                            &input_c_backprop_data, &params_backprop_data,
+                            &reserve_space_uint8, workspace_allocator,
+                            output_profile_result)
           .ok();
   return launch_success
       ? Status::OK()
@@ -932,7 +992,7 @@ class CudnnRNNKernelCommon : public OpKernel {
   bool ResetRndGenState() { return reset_rnd_gen_state_; }
 
   template <typename T>
-  Status ExtractCudnnRNNParamsInfo(OpKernelContext* context,
+  Status ExtractCudnnRNNParamsInfo(OpKernelContext* context, int num_proj,
                                    std::unique_ptr<RnnDescriptor>* rnn_desc) {
     const Tensor* num_layers_t = nullptr;
     TF_RETURN_IF_ERROR(context->input("num_layers", &num_layers_t));
@@ -953,6 +1013,9 @@ class CudnnRNNKernelCommon : public OpKernel {
     }
     int input_size = input_size_t->scalar<int>()();
 
+    int h_num_units = (num_proj == 0 ? num_units : num_proj);
+    int c_num_units = (num_proj == 0 ? 0 : num_units);
+
     RnnInputMode input_mode;
     TF_RETURN_IF_ERROR(
         ToRNNInputMode(rnn_input_mode(), num_units, input_size, &input_mode));
@@ -962,9 +1025,10 @@ class CudnnRNNKernelCommon : public OpKernel {
     // random number generator, therefore set state_allocator to nullptr.
     const AlgorithmConfig algo_config;
     auto rnn_desc_s = stream->parent()->createRnnDescriptor(
-        num_layers, num_units, input_size, /*batch_size=*/0, input_mode,
-        rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config,
-        dropout(), seed(), /* state_allocator=*/nullptr);
+        num_layers, h_num_units, input_size, /*c_size=*/c_num_units,
+        /*batch_size=*/0, input_mode, rnn_direction_mode(), rnn_mode(),
+        ToDataType<T>::value, algo_config, dropout(), seed(),
+        /* state_allocator=*/nullptr);
     if (!rnn_desc_s.ok()) {
       return FromExecutorStatus(rnn_desc_s);
     }
@@ -983,9 +1047,9 @@ class CudnnRNNKernelCommon : public OpKernel {
     se::dnn::DataType data_type = ToDataType<T>::value;
     auto rnn_desc_s = executor->createRnnDescriptor(
         model_shapes.num_layers, model_shapes.num_units,
-        model_shapes.input_size, model_shapes.batch_size, input_mode,
-        rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(),
-        seed(), dropout_state_allocator);
+        model_shapes.input_size, model_shapes.c_num_units,
+        model_shapes.batch_size, input_mode, rnn_direction_mode(), rnn_mode(),
+        data_type, algo_config, dropout(), seed(), dropout_state_allocator);
     TF_RETURN_IF_ERROR(rnn_desc_s.status());
 
     *rnn_desc = rnn_desc_s.ConsumeValueOrDie();
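The helper above splits one num_units into two widths before building the RNN descriptor: h_num_units is what actually recurs, and c_num_units is nonzero only when projection is active. In sketch form:

    # Sketch: the unit-width split passed into createRnnDescriptor above.
    def descriptor_widths(num_units, num_proj):
        h_num_units = num_units if num_proj == 0 else num_proj
        c_num_units = 0 if num_proj == 0 else num_units
        return h_num_units, c_num_units

    assert descriptor_widths(32, 0) == (32, 0)    # plain LSTM
    assert descriptor_widths(32, 16) == (16, 32)  # LSTMP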
@@ -1035,11 +1099,18 @@ template <typename T, typename Index>
 class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
  public:
   explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
-      : CudnnRNNKernelCommon(context) {}
+      : CudnnRNNKernelCommon(context) {
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
+  }
 
   void Compute(OpKernelContext* context) override {
     std::unique_ptr<RnnDescriptor> rnn_desc;
-    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
+    OP_REQUIRES_OK(context,
+                   ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
     CHECK(params_size_in_bytes % sizeof(T) == 0)
         << "params_size_in_bytes must be multiple of element size";
@@ -1049,6 +1120,9 @@ class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
     OP_REQUIRES_OK(context, context->allocate_output(0, {1}, &output_t));
     *output_t->template flat<Index>().data() = params_size;
   }
 
+ private:
+  int num_proj_;
 };
 
 #define REGISTER_GPU(T) \
@@ -1074,7 +1148,33 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
  public:
   explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
       : CudnnRNNKernelCommon(context) {
-    OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
+    if (context->HasAttr("num_params")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
+    } else {
+      num_params_ = 0;
+    }
+    if (context->HasAttr("num_params_weights")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_params_weights",
+                                               &num_params_weights_));
+    } else {
+      num_params_weights_ = 0;
+    }
+    if (context->HasAttr("num_params_biases")) {
+      OP_REQUIRES_OK(context,
+                     context->GetAttr("num_params_biases",
+                                      &num_params_biases_));
+    } else {
+      num_params_biases_ = 0;
+    }
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
+    if (num_proj_ == 0) {
+      num_params_weights_ = num_params_;
+      num_params_biases_ = num_params_;
+    }
   }
 
   void Compute(OpKernelContext* context) override {
@@ -1083,7 +1183,8 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
     Stream* stream = context->op_device_context()->stream();
 
     std::unique_ptr<RnnDescriptor> rnn_desc;
-    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
+    OP_REQUIRES_OK(context,
+                   ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
     CHECK(params_size_in_bytes % sizeof(T) == 0)
         << "params_size_in_bytes must be multiple of element size";
@@ -1109,25 +1210,38 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
     if (rnn_direction_mode() == RnnDirectionMode::kRnnBidirectional) {
       num_dirs = 2;
     }
-    const int num_params_per_layer = num_params_ / num_layers / num_dirs;
+    const int num_params_weights_per_layer =
+        num_params_weights_ / num_layers / num_dirs;
     // Number of params applied on inputs. The rest are applied on recurrent
     // hidden states.
-    const int num_params_input_state = num_params_per_layer / 2;
-    CHECK(num_params_ % (num_layers * num_dirs) == 0)
-        << "Number of params is not a multiple of num_layers * num_dirs.";
-    CHECK(num_params_per_layer % 2 == 0)
-        << "Number of params per layer is not a even number.";
+    const int num_params_input_state = num_params_weights_per_layer / 2;
+    CHECK(num_params_weights_ % (num_layers * num_dirs) == 0)
+        << "Number of params (weights) is not a multiple of num_layers * "
+           "num_dirs.";
+    CHECK(num_params_biases_ % (num_layers * num_dirs) == 0)
+        << "Number of params (bias) is not a multiple of num_layers * "
+           "num_dirs.";
+    if (num_proj_ == 0) {
+      CHECK(num_params_weights_per_layer % 2 == 0)
+          << "Number of params per layer is not a even number w/o projection.";
+    } else {
+      CHECK(num_params_weights_per_layer % 2 != 0)
+          << "Number of params per layer is not a odd number w/ projection.";
+    }
 
-    CHECK(num_params_ == rnn_desc->ParamsWeightRegions().size())
-        << "Number of params mismatch. Expected " << num_params_ << ", got "
-        << rnn_desc->ParamsWeightRegions().size();
+    CHECK(num_params_weights_ == rnn_desc->ParamsWeightRegions().size())
+        << "C Number of params mismatch. Expected " << num_params_weights_
+        << ", got " << rnn_desc->ParamsWeightRegions().size();
+    int h_num_units = (num_proj_ == 0 ? num_units : num_proj_);
+    int c_num_units = (num_proj_ == 0 ? 0 : num_units);
     for (int i = 0; i < rnn_desc->ParamsWeightRegions().size(); i++) {
       int64 size_in_bytes = rnn_desc->ParamsWeightRegions()[i].size;
       int64 size = size_in_bytes / sizeof(T);
-      const int layer_idx = i / num_params_per_layer;
-      const int index_within_layer = i % num_params_per_layer;
-      int width = 0, height = num_units;
-      // In CuDNN layout, each layer has num_params_per_layer params, with the
+      const int layer_idx = i / num_params_weights_per_layer;
+      const int index_within_layer = i % num_params_weights_per_layer;
+      int width = 0, height = (num_proj_ == 0 ? h_num_units : c_num_units);
+      // In CuDNN layout, each layer has num_params_weights_per_layer params,
+      // with the
       // first half a.k.a num_params_input_state params applied on the inputs,
       // and the second half on the recurrent hidden states.
       bool apply_on_input_state = index_within_layer < num_params_input_state;
|
|||||||
if (layer_idx == 0 && apply_on_input_state) {
|
if (layer_idx == 0 && apply_on_input_state) {
|
||||||
width = input_size;
|
width = input_size;
|
||||||
} else {
|
} else {
|
||||||
width = num_units;
|
width = h_num_units;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (apply_on_input_state) {
|
if (apply_on_input_state) {
|
||||||
@ -1145,15 +1259,19 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
|||||||
} else {
|
} else {
|
||||||
// Following layers, cell inputs are concatenated outputs of
|
// Following layers, cell inputs are concatenated outputs of
|
||||||
// its prior layer.
|
// its prior layer.
|
||||||
width = 2 * num_units;
|
width = 2 * h_num_units;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
width = num_units;
|
width = h_num_units;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
CHECK(size == width * height) << "Params size mismatch. Expected "
|
CHECK(size == width * height) << "Params size mismatch. Expected "
|
||||||
<< width * height << ", got " << size;
|
<< width * height << ", got " << size;
|
||||||
Tensor* output = nullptr;
|
Tensor* output = nullptr;
|
||||||
|
int id_in_layer = i % num_params_weights_per_layer;
|
||||||
|
if (num_proj_ != 0 && id_in_layer == num_params_weights_per_layer-1) {
|
||||||
|
std::swap(height, width);
|
||||||
|
}
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(
|
OP_REQUIRES_OK(context, context->allocate_output(
|
||||||
i, TensorShape({height, width}), &output));
|
i, TensorShape({height, width}), &output));
|
||||||
DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
|
DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
|
||||||
@ -1162,10 +1280,11 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
|
|||||||
stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
|
stream->ThenMemcpy(&data_dst_ptr, data_src_ptr, size_in_bytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
OP_REQUIRES(context, num_params_ == rnn_desc->ParamsBiasRegions().size(),
|
OP_REQUIRES(
|
||||||
errors::InvalidArgument("Number of params mismatch. Expected ",
|
context, num_params_biases_ == rnn_desc->ParamsBiasRegions().size(),
|
||||||
num_params_, ", got ",
|
errors::InvalidArgument("A Number of params mismatch. Expected ",
|
||||||
rnn_desc->ParamsBiasRegions().size()));
|
num_params_biases_, ", got ",
|
||||||
|
rnn_desc->ParamsBiasRegions().size()));
|
||||||
for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
|
for (int i = 0; i < rnn_desc->ParamsBiasRegions().size(); i++) {
|
||||||
int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
|
int64 size_in_bytes = rnn_desc->ParamsBiasRegions()[i].size;
|
||||||
int64 size = size_in_bytes / sizeof(T);
|
int64 size = size_in_bytes / sizeof(T);
|
||||||
@@ -1175,7 +1294,7 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
 
       Tensor* output = nullptr;
       OP_REQUIRES_OK(context,
-                     context->allocate_output(num_params_ + i,
+                     context->allocate_output(num_params_weights_ + i,
                                               TensorShape({size}), &output));
       DeviceMemoryBase data_src_ptr = SliceDeviceMemory(
           input_ptr, rnn_desc->ParamsBiasRegions()[i].offset, size_in_bytes);
@@ -1186,6 +1305,9 @@ class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
 
  private:
   int num_params_;
+  int num_params_weights_;
+  int num_params_biases_;
+  int num_proj_;
 };
 
 #define REGISTER_GPU(T) \
@@ -1201,17 +1323,37 @@ TF_CALL_float(REGISTER_GPU);
 TF_CALL_double(REGISTER_GPU);
 #undef REGISTER_GPU
 
+#define REGISTER_GPU(T)                                       \
+  REGISTER_KERNEL_BUILDER(Name("CudnnRNNParamsToCanonicalV2") \
+                              .Device(DEVICE_GPU)             \
+                              .HostMemory("num_layers")       \
+                              .HostMemory("num_units")        \
+                              .HostMemory("input_size")       \
+                              .TypeConstraint<T>("T"),        \
+                          CudnnRNNParamsToCanonical<GPUDevice, T>);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+#undef REGISTER_GPU
+
 // Convert weight and bias params from the canonical form to a
 // platform-specific layout.
 template <typename T>
 class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
  public:
   explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
-      : CudnnRNNKernelCommon(context) {}
+      : CudnnRNNKernelCommon(context) {
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
+  }
 
   void Compute(OpKernelContext* context) override {
     std::unique_ptr<RnnDescriptor> rnn_desc;
-    OP_REQUIRES_OK(context, ExtractCudnnRNNParamsInfo<T>(context, &rnn_desc));
+    OP_REQUIRES_OK(context,
+                   ExtractCudnnRNNParamsInfo<T>(context, num_proj_, &rnn_desc));
     int64 params_size_in_bytes = rnn_desc->ParamsSizeInBytes();
     CHECK(params_size_in_bytes % sizeof(T) == 0)
         << "params_size_in_bytes must be multiple of element size";
@@ -1232,6 +1374,9 @@ class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
     RestoreParams<T>(biases, rnn_desc->ParamsBiasRegions(), &output_ptr,
                      stream);
   }
 
+ private:
+  int num_proj_;
 };
 
 #define REGISTER_GPU(T) \
@@ -1247,6 +1392,19 @@ TF_CALL_float(REGISTER_GPU);
 TF_CALL_double(REGISTER_GPU);
 #undef REGISTER_GPU
 
+#define REGISTER_GPU(T)                                       \
+  REGISTER_KERNEL_BUILDER(Name("CudnnRNNCanonicalToParamsV2") \
+                              .Device(DEVICE_GPU)             \
+                              .HostMemory("num_layers")       \
+                              .HostMemory("num_units")        \
+                              .HostMemory("input_size")       \
+                              .TypeConstraint<T>("T"),        \
+                          CudnnRNNCanonicalToParams<GPUDevice, T>);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+#undef REGISTER_GPU
+
 // Run the forward operation of the RNN model.
 template <typename T>
 class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
@@ -1264,14 +1422,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
   void Compute(OpKernelContext* context) override {
     AlgorithmConfig algo_config;
     ComputeAndReturnAlgorithm(context, &algo_config, /*var_seq_lengths=*/false,
-                              /*time_major=*/true);
+                              /*time_major=*/true, /*num_proj=*/0);
   }
 
  protected:
   virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
                                          AlgorithmConfig* output_algo_config,
                                          bool var_seq_lengths,
-                                         bool time_major) {
+                                         bool time_major, int num_proj) {
     CHECK_NE(output_algo_config, nullptr);
 
     const Tensor* input = nullptr;
@@ -1284,11 +1442,13 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
       OP_REQUIRES_OK(context,
                      ExtractForwardInput(context, model_types(), time_major,
                                          &input, &input_h, &input_c, &params,
-                                         &sequence_lengths, &model_shapes));
+                                         &sequence_lengths, num_proj,
+                                         &model_shapes));
     } else {
       OP_REQUIRES_OK(context, ExtractForwardInput(
                                   context, model_types(), time_major, &input,
-                                  &input_h, &input_c, &params, &model_shapes));
+                                  &input_h, &input_c, &params, num_proj,
+                                  &model_shapes));
     }
     RnnInputMode input_mode;
     OP_REQUIRES_OK(context,
@@ -1362,13 +1522,14 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
                          Tensor** output_c) {
     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
     const TensorShape& output_shape = model_shapes.output_shape;
+    const TensorShape& c_state_shape = model_shapes.c_state_shape;
 
     TF_RETURN_IF_ERROR(context->allocate_output(0, output_shape, output));
     TF_RETURN_IF_ERROR(
         context->allocate_output(1, hidden_state_shape, output_h));
     if (HasInputC()) {
       TF_RETURN_IF_ERROR(
-          context->allocate_output(2, hidden_state_shape, output_c));
+          context->allocate_output(2, c_state_shape, output_c));
     } else {
       // Only LSTM uses input_c and output_c. So for all other models, we only
       // need to create dummy outputs.
@@ -1414,7 +1575,7 @@ class CudnnRNNForwardOpV2<GPUDevice, T>
     AlgorithmConfig best_algo_config;
     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
         context, &best_algo_config, /*var_seq_lengths=*/false,
-        /*time_major=*/true);
+        /*time_major=*/true, /*num_proj=*/0);
     if (!context->status().ok()) {
       return;
     }
@@ -1613,13 +1774,18 @@ class CudnnRNNForwardOpV3<GPUDevice, T>
   explicit CudnnRNNForwardOpV3(OpKernelConstruction* context)
       : CudnnRNNForwardOp<GPUDevice, T>(context) {
     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
   }
 
   void Compute(OpKernelContext* context) override {
     AlgorithmConfig best_algo_config;
     CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
         context, &best_algo_config, /*var_seq_lengths=*/true,
-        /*time_major=*/time_major());
+        /*time_major=*/time_major(), num_proj_);
     if (!context->status().ok()) {
       return;
     }
@@ -1631,6 +1797,9 @@ class CudnnRNNForwardOpV3<GPUDevice, T>
     OP_REQUIRES_OK(context,
                    context->allocate_output(4, {}, &output_host_reserved));
   }
 
+ private:
+  int num_proj_;
 };
 
 #define REGISTER_GPU(T) \
@@ -1654,12 +1823,12 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
       : CudnnRNNKernelCommon(context) {}
 
   void Compute(OpKernelContext* context) override {
-    ComputeImpl(context, false, true);
+    ComputeImpl(context, false, true, 0);
   }
 
  protected:
   virtual void ComputeImpl(OpKernelContext* context, bool var_seq_lengths,
-                           bool time_major) {
+                           bool time_major, int num_proj) {
     const Tensor* input = nullptr;
     const Tensor* input_h = nullptr;
    const Tensor* input_c = nullptr;
@@ -1670,11 +1839,13 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
       OP_REQUIRES_OK(context,
                      ExtractForwardInput(context, model_types(), time_major,
                                          &input, &input_h, &input_c, &params,
-                                         &sequence_lengths, &model_shapes));
+                                         &sequence_lengths, num_proj,
+                                         &model_shapes));
     } else {
       OP_REQUIRES_OK(context, ExtractForwardInput(
                                   context, model_types(), time_major, &input,
-                                  &input_h, &input_c, &params, &model_shapes));
+                                  &input_h, &input_c, &params, num_proj,
+                                  &model_shapes));
     }
     RnnInputMode input_mode;
     OP_REQUIRES_OK(context,
@@ -1757,6 +1928,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
     TF_RETURN_IF_ERROR(context->input("reserve_space", reserve_space));
     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
     const TensorShape& output_shape = model_shapes.output_shape;
+    const TensorShape& c_state_shape = model_shapes.c_state_shape;
 
     if (output_shape != (*output)->shape()) {
       return errors::InvalidArgument(
@@ -1782,16 +1954,16 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
     }
 
     if (model_types.HasInputC()) {
-      if (hidden_state_shape != (*output_c)->shape()) {
+      if (c_state_shape != (*output_c)->shape()) {
         return errors::InvalidArgument(
             "Invalid output_c shape: ", (*output_c)->shape().DebugString(), " ",
-            hidden_state_shape.DebugString());
+            c_state_shape.DebugString());
       }
-      if (hidden_state_shape != (*output_c_backprop)->shape()) {
+      if (c_state_shape != (*output_c_backprop)->shape()) {
         return errors::InvalidArgument(
             "Invalid output_c_backprop shape: ",
             (*output_c_backprop)->shape().DebugString(), " ",
-            hidden_state_shape.DebugString());
+            c_state_shape.DebugString());
       }
     }
     return Status::OK();
@@ -1804,6 +1976,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
                          Tensor** input_c_backprop, Tensor** params_backprop) {
     const TensorShape& input_shape = model_shapes.input_shape;
     const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
+    const TensorShape& c_state_shape = model_shapes.c_state_shape;
 
     TF_RETURN_IF_ERROR(
         context->allocate_output(0, input_shape, input_backprop));
@@ -1811,7 +1984,7 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
         context->allocate_output(1, hidden_state_shape, input_h_backprop));
     if (HasInputC()) {
       TF_RETURN_IF_ERROR(
-          context->allocate_output(2, hidden_state_shape, input_c_backprop));
+          context->allocate_output(2, c_state_shape, input_c_backprop));
     } else {
       // Only LSTM uses input_c and output_c. So for all other models, we only
       // need to create dummy outputs.
@@ -1879,11 +2052,20 @@ class CudnnRNNBackwardOpV3<GPUDevice, T>
   explicit CudnnRNNBackwardOpV3(OpKernelConstruction* context)
       : CudnnRNNBackwardOp<GPUDevice, T>(context) {
     OP_REQUIRES_OK(context, context->GetAttr("time_major", &time_major_));
+    if (context->HasAttr("num_proj")) {
+      OP_REQUIRES_OK(context, context->GetAttr("num_proj", &num_proj_));
+    } else {
+      num_proj_ = 0;
+    }
   }

   void Compute(OpKernelContext* context) override {
-    CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major());
+    CudnnRNNBackwardOp<GPUDevice, T>::ComputeImpl(context, true, time_major(),
+                                                  num_proj_);
   }

+ private:
+  int num_proj_;
 };

 #define REGISTER_GPU(T) \
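The HasAttr guard keeps older serialized graphs loadable: a GraphDef produced before this change carries no num_proj attr, so the kernel must default it to 0 (no projection) rather than fail. A minimal Python sketch of the same default-if-absent pattern, on a hypothetical NodeDef:

    # Sketch only: a NodeDef serialized before this change lacks num_proj.
    from tensorflow.core.framework import node_def_pb2

    node = node_def_pb2.NodeDef(name="rnn_bak", op="CudnnRNNBackpropV3")
    num_proj = node.attr["num_proj"].i if "num_proj" in node.attr else 0
    assert num_proj == 0  # absent attr -> no projection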
@@ -49,6 +49,7 @@ REGISTER_OP("CudnnRNNParamsSize")
     .Attr("dropout: float = 0.0")
     .Attr("seed: int = 0")
     .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
     .Output("params_size: S")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle unused;
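With the new attr in place, the raw op can report the opaque parameter-buffer size of a projected LSTM. A hedged usage sketch (assumption: a CUDA-enabled build, since the kernel is GPU-only; the argument list matches the API goldens updated at the end of this diff):

    # num_proj=0 keeps the old, unprojected behavior.
    import tensorflow as tf

    params_size = tf.raw_ops.CudnnRNNParamsSize(
        num_layers=1, num_units=128, input_size=32,
        T=tf.float32, S=tf.int32,
        rnn_mode="lstm", num_proj=64)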
@@ -166,11 +167,13 @@ REGISTER_OP("CudnnRNNV3")
     .Attr("dropout: float = 0.0")
     .Attr("seed: int = 0")
     .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
     .Attr("is_training: bool = true")
     .Attr("time_major: bool = true")
     .SetShapeFn([](InferenceContext* c) {
       auto input_shape = c->input(0);
       auto input_h_shape = c->input(1);
+      auto input_c_shape = c->input(2);
       auto max_seq_length = c->Dim(input_shape, 0);
       auto batch_size = c->Dim(input_shape, 1);
       auto num_units = c->Dim(input_h_shape, 2);

@@ -185,7 +188,7 @@ REGISTER_OP("CudnnRNNV3")
           c->MakeShape({max_seq_length, batch_size, output_size});
       auto output_h_shape = input_h_shape;
       auto output_c_shape TF_ATTRIBUTE_UNUSED =
-          (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
+          (rnn_mode == "lstm") ? input_c_shape : c->MakeShape({});
       c->set_output(0, output_shape);
       c->set_output(1, output_h_shape);
       c->set_output(2, output_c_shape);
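This shape-function change is the user-visible contract: output_h still follows input_h (whose innermost dimension is num_proj when projecting), while output_c now follows input_c (innermost dimension num_units) instead of mirroring output_h. A shape-only sketch under those assumptions:

    # Shape-only sketch of the CudnnRNNV3 state tensors for an LSTMP cell.
    num_layers, num_dirs, batch = 1, 1, 8
    num_units, num_proj = 128, 64

    input_h_shape = [num_layers * num_dirs, batch, num_proj]   # projected state
    input_c_shape = [num_layers * num_dirs, batch, num_units]  # full cell state
    output_h_shape = input_h_shape   # unchanged rule: follows input_h
    output_c_shape = input_c_shape   # new rule: follows input_c, not output_h
    assert output_h_shape[2] <= output_c_shape[2]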
@@ -293,6 +296,7 @@ REGISTER_OP("CudnnRNNBackpropV3")
     .Attr("dropout: float = 0.0")
     .Attr("seed: int = 0")
     .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
     .Attr("time_major: bool = true")
     .SetShapeFn([](InferenceContext* c) {
       auto input_shape = c->input(0);
@@ -338,6 +342,43 @@ REGISTER_OP("CudnnRNNParamsToCanonical")
       return Status::OK();
     });

+REGISTER_OP("CudnnRNNParamsToCanonicalV2")
+    .Input("num_layers: int32")
+    .Input("num_units: int32")
+    .Input("input_size: int32")
+    .Input("params: T")
+    .Output("weights: num_params_weights * T")
+    .Output("biases: num_params_biases * T")
+    .Attr("T: {float16, float32, float64}")
+    .Attr("num_params_weights: int")
+    .Attr("num_params_biases: int")
+    .Attr(kRNNModeAttrs)
+    .Attr(kRNNInputModeAttrs)
+    .Attr(kRNNDirectionAttrs)
+    .Attr("dropout: float = 0.0")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
+      int num_params_weights;
+      int num_params_biases;
+      TF_RETURN_IF_ERROR(c->GetAttr("num_params_weights", &num_params_weights));
+      TF_RETURN_IF_ERROR(c->GetAttr("num_params_biases", &num_params_biases));
+      // Set shape for weight matrices
+      for (int i = 0; i < num_params_weights; i++) {
+        c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
+                                   InferenceContext::kUnknownDim));
+      }
+      // Set shape for bias vectors
+      for (int i = 0; i < num_params_biases; i++) {
+        c->set_output(num_params_weights + i,
+                      c->Vector(InferenceContext::kUnknownDim));
+      }
+      return Status::OK();
+    });
+
 REGISTER_OP("CudnnRNNCanonicalToParams")
     .Input("num_layers: int32")
     .Input("num_units: int32")

@@ -358,4 +399,26 @@ REGISTER_OP("CudnnRNNCanonicalToParams")
       return Status::OK();
     });

+REGISTER_OP("CudnnRNNCanonicalToParamsV2")
+    .Input("num_layers: int32")
+    .Input("num_units: int32")
+    .Input("input_size: int32")
+    .Input("weights: num_params_weights * T")
+    .Input("biases: num_params_biases * T")
+    .Output("params: T")
+    .Attr("T: {float16, float32, float64}")
+    .Attr("num_params_weights: int")
+    .Attr("num_params_biases: int")
+    .Attr(kRNNModeAttrs)
+    .Attr(kRNNInputModeAttrs)
+    .Attr(kRNNDirectionAttrs)
+    .Attr("dropout: float = 0.0")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
+    .Attr("num_proj: int = 0")
+    .SetShapeFn([](InferenceContext* c) {
+      c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
+      return Status::OK();
+    });
+
 } // namespace tensorflow
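The V2 ops count weights and biases separately (num_params_weights / num_params_biases) because projection changes the weight count but not the bias count. A rough accounting sketch, under the assumption, grounded in the cudnn region layout used later in this diff, that each LSTM layer has 8 gate weight matrices and 8 biases, and projection adds one extra weight matrix (linLayerID 8) per layer and direction:

    # Sketch of the canonical-region bookkeeping, not a TensorFlow API.
    def canonical_region_counts(num_layers, num_dirs=1, has_projection=False):
        per_layer_weights = 8 + (1 if has_projection else 0)
        return (num_layers * num_dirs * per_layer_weights,  # num_params_weights
                num_layers * num_dirs * 8)                  # num_params_biases

    assert canonical_region_counts(1, has_projection=True) == (9, 8)
    assert canonical_region_counts(1, has_projection=False) == (8, 8)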
@@ -98,6 +98,7 @@ def _cudnn_rnn_backwardv3(op, *grads):
       seed=op.get_attr("seed"),
       seed2=op.get_attr("seed2"),
       time_major=op.get_attr("time_major"),
+      num_proj=op.get_attr("num_proj"),
       rnn_mode=op.get_attr("rnn_mode"),
       input_mode=op.get_attr("input_mode"),
       direction=op.get_attr("direction")) + (None,)
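The gradient function must forward num_proj because gradient code only sees the forward op: any attr the backward kernel needs has to be read back with op.get_attr. A toy demonstration of that mechanism using a standard op's attr (TF1-style graph mode, since gradient functions run against graph ops):

    # Demonstrates op.get_attr, the mechanism the gradient above relies on.
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    a = tf.placeholder(tf.float32, [3, 2])
    b = tf.placeholder(tf.float32, [3, 4])
    c = tf.matmul(a, b, transpose_a=True)
    assert c.op.get_attr("transpose_a") is True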
@@ -1002,8 +1002,8 @@ class CudnnRnnParamsDescriptor {
 class CudnnRnnDescriptor : public dnn::RnnDescriptor {
   CudnnRnnDescriptor(const CudnnHandle& cudnn, gpu::RnnDescriptor rnn_desc,
                      PersistentRnnPlan rnn_plan, int num_layers,
-                     int hidden_size, int input_size, int batch_size,
-                     cudnnRNNInputMode_t input_mode,
+                     int hidden_size, int input_size, int c_size,
+                     int batch_size, cudnnRNNInputMode_t input_mode,
                      cudnnDirectionMode_t direction_mode,
                      cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
                      cudnnDataType_t compute_type,

@@ -1015,6 +1015,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
         num_layers_(num_layers),
         hidden_size_(hidden_size),
         input_size_(input_size),
+        c_size_(c_size),
         batch_size_(batch_size),
         rnn_algo_(ToCudnnRNNAlgo(algorithm_config.algorithm())),
         input_mode_(input_mode),

@@ -1031,7 +1032,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {

   static port::StatusOr<CudnnRnnDescriptor> Create(
       const CudnnHandle& cudnn, int num_layers, int hidden_size, int input_size,
-      int batch_size, cudnnRNNInputMode_t input_mode,
+      int c_size, int batch_size, cudnnRNNInputMode_t input_mode,
       cudnnDirectionMode_t direction_mode, cudnnRNNMode_t rnn_mode,
       cudnnDataType_t data_type, cudnnDataType_t compute_type,
       const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
|
|||||||
cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
|
cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
|
||||||
|
|
||||||
// TODO: allow the user to choose an algorithm.
|
// TODO: allow the user to choose an algorithm.
|
||||||
RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
|
if (c_size != 0 && hidden_size < c_size) {
|
||||||
cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/hidden_size,
|
RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
|
||||||
/*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(),
|
cudnn.handle(), /*rnnDesc=*/rnn_desc.get(), /*hiddenSize=*/c_size,
|
||||||
/*inputMode=*/input_mode, /*direction=*/direction_mode,
|
/*numLayers=*/num_layers, /*dropoutDesc=*/dropout_desc.handle(),
|
||||||
/*mode=*/rnn_mode, /*algo=*/rnn_algo,
|
/*inputMode=*/input_mode, /*direction=*/direction_mode,
|
||||||
/*dataType=*/compute_type));
|
/*mode=*/rnn_mode, /*algo=*/rnn_algo, /*dataType=*/compute_type));
|
||||||
|
#if CUDNN_VERSION >= 7101
|
||||||
|
RETURN_IF_CUDNN_ERROR(cudnnSetRNNProjectionLayers(
|
||||||
|
cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
|
||||||
|
/*recProjSize=*/hidden_size, /*outProjSize=*/0));
|
||||||
|
#else
|
||||||
|
return port::Status(port::error::INVALID_ARGUMENT,
|
||||||
|
"No supported cudnnSetRNNProjectionLayers when "
|
||||||
|
"CUDNN_VERSION < 7.1.1");
|
||||||
|
#endif
|
||||||
|
} else {
|
||||||
|
RETURN_IF_CUDNN_ERROR(cudnnSetRNNDescriptor_v6(
|
||||||
|
cudnn.handle(), /*rnnDesc=*/rnn_desc.get(),
|
||||||
|
/*hiddenSize=*/hidden_size, /*numLayers=*/num_layers,
|
||||||
|
/*dropoutDesc=*/dropout_desc.handle(), /*inputMode=*/input_mode,
|
||||||
|
/*direction=*/direction_mode, /*mode=*/rnn_mode, /*algo=*/rnn_algo,
|
||||||
|
/*dataType=*/compute_type));
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
|
// TODO: For now, we only use cudnnRNN**Ex API to process padded inputs.
|
||||||
// But in the future if these APIs are used to process full length arrays,
|
// But in the future if these APIs are used to process full length arrays,
|
||||||
@ -1098,9 +1116,9 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
|
return CudnnRnnDescriptor(cudnn, std::move(rnn_desc), std::move(rnn_plan),
|
||||||
num_layers, hidden_size, input_size, batch_size,
|
num_layers, hidden_size, input_size, c_size,
|
||||||
input_mode, direction_mode, rnn_mode, data_type,
|
batch_size, input_mode, direction_mode, rnn_mode,
|
||||||
compute_type, algorithm_config,
|
data_type, compute_type, algorithm_config,
|
||||||
std::move(dropout_desc), std::move(params_desc));
|
std::move(dropout_desc), std::move(params_desc));
|
||||||
}
|
}
|
||||||
|
|
||||||
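Note the indirection: when projection is active, cudnn's hiddenSize is set to the larger cell size (c_size, i.e. num_units) and cudnnSetRNNProjectionLayers then declares the smaller hidden_size (num_proj) as the recurrent projection size. A plain-Python mirror of that decision, as an illustration only, not a real API:

    # c_size != 0 and hidden_size < c_size means "LSTM with projection".
    def descriptor_sizes(hidden_size, c_size):
        if c_size != 0 and hidden_size < c_size:
            return {"cudnnHiddenSize": c_size, "recProjSize": hidden_size}
        return {"cudnnHiddenSize": hidden_size, "recProjSize": None}

    assert descriptor_sizes(64, 128) == {"cudnnHiddenSize": 128, "recProjSize": 64}
    assert descriptor_sizes(128, 0) == {"cudnnHiddenSize": 128, "recProjSize": None}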
@@ -1108,6 +1126,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
   int num_layers() const { return num_layers_; }
   int hidden_size() const { return hidden_size_; }
   int input_size() const { return input_size_; }
+  int c_size() const { return c_size_; }
   int batch_size() const { return batch_size_; }
   cudnnRNNInputMode_t input_mode() const { return input_mode_; }
   cudnnDirectionMode_t direction_mode() const { return direction_mode_; }

@@ -1136,6 +1155,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
   int num_layers_;
   int hidden_size_;
   int input_size_;
+  int c_size_;
   // batch_size_ is set to -1 when not using CUDNN_RNN_ALGO_PERSIST_DYNAMIC
   // algorithm.
   int batch_size_;
@@ -1240,6 +1260,62 @@ port::StatusOr<CudnnRnnParamsDescriptor> CudnnRnnParamsDescriptor::Create(
         (type == 0 ? weights : biases).push_back(region);
       }
     }
+    int hidden_size_v;
+    int num_layers_v;
+    cudnnDropoutDescriptor_t dropout_desc;
+    cudnnRNNInputMode_t input_mode;
+    cudnnDirectionMode_t direction;
+    cudnnRNNMode_t mode;
+    cudnnRNNAlgo_t algo;
+    cudnnDataType_t data_type;
+    RETURN_IF_CUDNN_ERROR(cudnnGetRNNDescriptor(
+        /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+        /*hiddenSize=*/&hidden_size_v,
+        /*numLayers=*/&num_layers_v,
+        /*dropoutDesc=*/&dropout_desc,
+        /*inputMode=*/&input_mode,
+        /*direction=*/&direction,
+        /*mode=*/&mode,
+        /*algo=*/&algo,
+        /*dataType=*/&data_type));
+    int rec_proj_size_v;
+    int out_proj_size_v;
+#if CUDNN_VERSION >= 7101
+    RETURN_IF_CUDNN_ERROR(cudnnGetRNNProjectionLayers(
+        /*handle=*/cudnn.handle(),
+        /*rnnDesc=*/rnn_desc,
+        /*recProjSize*/ &rec_proj_size_v,
+        /*outProjSize*/ &out_proj_size_v));
+#else
+    return port::Status(port::error::INVALID_ARGUMENT,
+                        "No supported cudnnGetRNNProjectionLayers when "
+                        "CUDNN_VERSION < 7.1.1");
+#endif
+    if (rec_proj_size_v != hidden_size_v) {
+      void* offset = nullptr;
+      int region_id = 8;
+      RETURN_IF_CUDNN_ERROR(cudnnGetRNNLinLayerMatrixParams(
+          /*handle=*/cudnn.handle(), /*rnnDesc=*/rnn_desc,
+          /*layer=*/layer, /*xDesc=*/input_desc.get(),
+          /*wDesc=*/filter_desc.get(),
+          /*w=*/nullptr, /*linLayerID=*/region_id,
+          /*linLayerMatDesc=*/region_desc_handle.get(),
+          /*linLayerMat or linLayerBias=*/&offset));
+      int dims[] = {1, 1, 1};
+      cudnnDataType_t data_type;
+      cudnnTensorFormat_t tensor_format;
+      int n_dims;
+      RETURN_IF_CUDNN_ERROR(cudnnGetFilterNdDescriptor(
+          /*filterDesc=*/region_desc_handle.get(),
+          /*nbDimsRequested=*/sizeof(dims) / sizeof(dims[0]),
+          /*dataType=*/&data_type, /*format=*/&tensor_format,
+          /*nbDims=*/&n_dims, /*filterDimA=*/dims));
+      int64 size =
+          dims[0] * dims[1] * dims[2] * CudnnDataTypeToByteSize(data_type);
+      dnn::RnnDescriptor::ParamsRegion region = {
+          reinterpret_cast<int64>(offset), size};
+      weights.push_back(region);
+    }
   }

   return CudnnRnnParamsDescriptor(std::move(filter_desc), params_size_in_bytes,
|
|||||||
int max_seq_length = 0;
|
int max_seq_length = 0;
|
||||||
int hidden_size = 0;
|
int hidden_size = 0;
|
||||||
int input_size = 0;
|
int input_size = 0;
|
||||||
|
int c_size = 0;
|
||||||
int dir_count = 0;
|
int dir_count = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1429,6 +1506,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
|
|||||||
model_dims.max_seq_length = input_desc.max_seq_length();
|
model_dims.max_seq_length = input_desc.max_seq_length();
|
||||||
model_dims.hidden_size = rnn_desc.hidden_size();
|
model_dims.hidden_size = rnn_desc.hidden_size();
|
||||||
model_dims.input_size = input_desc.data_size();
|
model_dims.input_size = input_desc.data_size();
|
||||||
|
model_dims.c_size = rnn_desc.c_size();
|
||||||
model_dims.dir_count =
|
model_dims.dir_count =
|
||||||
(rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
|
(rnn_desc.direction_mode() == CUDNN_BIDIRECTIONAL) ? 2 : 1;
|
||||||
|
|
||||||
@ -1441,7 +1519,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
|
|||||||
}
|
}
|
||||||
if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
|
if (!(input_h_desc.num_layers() == input_c_desc.num_layers() &&
|
||||||
input_h_desc.batch_size() == input_c_desc.batch_size() &&
|
input_h_desc.batch_size() == input_c_desc.batch_size() &&
|
||||||
input_h_desc.data_size() == input_c_desc.data_size())) {
|
input_h_desc.data_size() <= input_c_desc.data_size())) {
|
||||||
return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
|
return port::Status(port::error::INVALID_ARGUMENT, "Invalid input_c shape");
|
||||||
}
|
}
|
||||||
if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
|
if (!(output_desc.max_seq_length() == model_dims.max_seq_length &&
|
||||||
@ -1458,7 +1536,7 @@ port::StatusOr<RnnModelDims> ExtractAndCheckRnnForward(
|
|||||||
}
|
}
|
||||||
if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
|
if (!(input_h_desc.num_layers() == output_c_desc.num_layers() &&
|
||||||
input_h_desc.batch_size() == output_c_desc.batch_size() &&
|
input_h_desc.batch_size() == output_c_desc.batch_size() &&
|
||||||
input_h_desc.data_size() == output_c_desc.data_size())) {
|
input_h_desc.data_size() <= output_c_desc.data_size())) {
|
||||||
return port::Status(port::error::INVALID_ARGUMENT,
|
return port::Status(port::error::INVALID_ARGUMENT,
|
||||||
"Invalid output_c shape");
|
"Invalid output_c shape");
|
||||||
}
|
}
|
||||||
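Relaxing == to <= is deliberate: with projection the h state's innermost dimension (num_proj) is strictly smaller than the c state's (num_units), and without projection the two are equal. A compact restatement of the accepted shapes:

    # Layers and batch must match exactly; h data size may be <= c data size.
    def states_compatible(h_dims, c_dims):
        return (h_dims[0] == c_dims[0] and   # num_layers * num_dirs
                h_dims[1] == c_dims[1] and   # batch_size
                h_dims[2] <= c_dims[2])      # num_proj <= num_units

    assert states_compatible((1, 8, 64), (1, 8, 128))    # LSTMP
    assert states_compatible((1, 8, 128), (1, 8, 128))   # plain LSTM
    assert not states_compatible((1, 8, 256), (1, 8, 128))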
@@ -1814,7 +1892,7 @@ port::Status CudnnSupport::DoRnnBackwardImpl(

 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
 CudnnSupport::createRnnDescriptor(
-    int num_layers, int hidden_size, int input_size, int batch_size,
+    int num_layers, int hidden_size, int input_size, int c_size, int batch_size,
     dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
     dnn::RnnMode rnn_mode, dnn::DataType data_type,
     const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,

@@ -1825,7 +1903,7 @@ CudnnSupport::createRnnDescriptor(
   SE_ASSIGN_OR_RETURN(
       CudnnRnnDescriptor rnn_desc,
       CudnnRnnDescriptor::Create(
-          cudnn, num_layers, hidden_size, input_size, batch_size,
+          cudnn, num_layers, hidden_size, input_size, c_size, batch_size,
           ToCudnnRnnInputMode(input_mode),
           ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
           ToCudnnDataType(data_type), GetRnnComputeType(data_type),

@@ -48,11 +48,11 @@ class CudnnSupport : public dnn::DnnSupport {
   port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;

   port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
-      int num_layers, int hidden_size, int input_size, int batch_size,
-      dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
-      dnn::RnnMode rnn_mode, dnn::DataType data_type,
-      const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
-      ScratchAllocator* state_allocator) override;
+      int num_layers, int hidden_size, int input_size, int c_size,
+      int batch_size, dnn::RnnInputMode input_mode,
+      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
+      dnn::DataType data_type, const dnn::AlgorithmConfig& algorithm_config,
+      float dropout, uint64 seed, ScratchAllocator* state_allocator) override;

   port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
   createRnnSequenceTensorDescriptor(int max_seq_length, int batch_size,

@@ -2052,7 +2052,7 @@ class DnnSupport {
   // is no longer in use.
   virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
   createRnnDescriptor(int num_layers, int hidden_size, int input_size,
-                      int batch_size, dnn::RnnInputMode input_mode,
+                      int c_size, int batch_size, dnn::RnnInputMode input_mode,
                       dnn::RnnDirectionMode direction_mode,
                       dnn::RnnMode rnn_mode, dnn::DataType data_type,
                       const dnn::AlgorithmConfig& algorithm_config,

@@ -379,7 +379,7 @@ bool StreamExecutor::GetBlasGemmAlgorithms(

 port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
 StreamExecutor::createRnnDescriptor(
-    int num_layers, int hidden_size, int input_size, int batch_size,
+    int num_layers, int hidden_size, int input_size, int c_size, int batch_size,
     dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
     dnn::RnnMode rnn_mode, dnn::DataType data_type,
     const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,

@@ -390,7 +390,7 @@ StreamExecutor::createRnnDescriptor(
         "Fail to find the dnn implementation.");
   }
   return dnn_support->createRnnDescriptor(
-      num_layers, hidden_size, input_size, batch_size, input_mode,
+      num_layers, hidden_size, input_size, c_size, batch_size, input_mode,
       direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
       state_allocator);
 }

@@ -405,11 +405,11 @@ class StreamExecutor {
   // Create an RNN descriptor based on model shapes and configurations.
   // The caller retains the ownership of the descriptor.
   port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
-      int num_layers, int hidden_size, int input_size, int batch_size,
-      dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
-      dnn::RnnMode rnn_mode, dnn::DataType data_type,
-      const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
-      ScratchAllocator *state_allocator);
+      int num_layers, int hidden_size, int input_size, int c_size,
+      int batch_size, dnn::RnnInputMode input_mode,
+      dnn::RnnDirectionMode direction_mode, dnn::RnnMode rnn_mode,
+      dnn::DataType data_type, const dnn::AlgorithmConfig &algorithm_config,
+      float dropout, uint64 seed, ScratchAllocator *state_allocator);

   // Create a RNN sequence descriptor that specifies either the input or output
   // sequence. The caller retains the ownership of the returned descriptor.
@@ -1036,6 +1036,14 @@ tf_module {
     name: "cross"
     argspec: "args=[\'a\', \'b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "cudnn_rnn_canonical_to_params_v2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
+  member_method {
+    name: "cudnn_rnn_params_to_canonical_v2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "cumprod"
     argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "

@@ -766,27 +766,35 @@ tf_module {
   }
   member_method {
     name: "CudnnRNNBackpropV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNCanonicalToParams"
     argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CudnnRNNCanonicalToParamsV2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CudnnRNNParamsSize"
-    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNParamsToCanonical"
     argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CudnnRNNParamsToCanonicalV2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CudnnRNNV2"
     argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
   }
   member_method {
     name: "Cumprod"
@@ -540,6 +540,14 @@ tf_module {
     name: "cosh"
    argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
+  member_method {
+    name: "cudnn_rnn_canonical_to_params_v2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
+  member_method {
+    name: "cudnn_rnn_params_to_canonical_v2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "cumsum"
     argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "

@@ -766,27 +766,35 @@ tf_module {
   }
   member_method {
     name: "CudnnRNNBackpropV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'output\', \'output_h\', \'output_c\', \'output_backprop\', \'output_h_backprop\', \'output_c_backprop\', \'reserve_space\', \'host_reserved\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNCanonicalToParams"
     argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CudnnRNNCanonicalToParamsV2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'weights\', \'biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CudnnRNNParamsSize"
-    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'T\', \'S\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNParamsToCanonical"
     argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'None\'], "
   }
+  member_method {
+    name: "CudnnRNNParamsToCanonicalV2"
+    argspec: "args=[\'num_layers\', \'num_units\', \'input_size\', \'params\', \'num_params_weights\', \'num_params_biases\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'None\'], "
+  }
   member_method {
     name: "CudnnRNNV2"
     argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'None\'], "
   }
   member_method {
     name: "CudnnRNNV3"
-    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
+    argspec: "args=[\'input\', \'input_h\', \'input_c\', \'params\', \'sequence_lengths\', \'rnn_mode\', \'input_mode\', \'direction\', \'dropout\', \'seed\', \'seed2\', \'num_proj\', \'is_training\', \'time_major\', \'name\'], varargs=None, keywords=None, defaults=[\'lstm\', \'linear_input\', \'unidirectional\', \'0\', \'0\', \'0\', \'0\', \'True\', \'True\', \'None\'], "
   }
   member_method {
     name: "Cumprod"
@@ -10,7 +10,7 @@ tf_module {
   }
   member_method {
     name: "audio"
-    argspec: "args=[\'name\', \'data\', \'sample_rate\', \'step\', \'max_outputs\', \'encoding\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'3\', \'None\', \'None\'], "
+    argspec: "args=[\'name\', \'data\', \'sample_rate\', \'step\', \'max_outputs\', \'encoding\', \'description\'], varargs=None, keywords=None, defaults=[\'3\', \'None\', \'None\'], "
   }
   member_method {
     name: "create_file_writer"

@@ -26,11 +26,11 @@ tf_module {
   }
   member_method {
     name: "histogram"
-    argspec: "args=[\'name\', \'data\', \'step\', \'buckets\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+    argspec: "args=[\'name\', \'data\', \'step\', \'buckets\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "image"
-    argspec: "args=[\'name\', \'data\', \'step\', \'max_outputs\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'3\', \'None\'], "
+    argspec: "args=[\'name\', \'data\', \'step\', \'max_outputs\', \'description\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
   }
   member_method {
     name: "import_event"

@@ -42,7 +42,7 @@ tf_module {
   }
   member_method {
     name: "scalar"
-    argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "summary_scope"

@@ -50,7 +50,7 @@ tf_module {
   }
   member_method {
     name: "text"
-    argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+    argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
   member_method {
     name: "trace_export"