Added BlockLSTMV2 and BlockLSTMGradV2 with IFCO layout
Note also that BlockLSTMV2 does not allow setting forget_bias (fixed at 0.0) and defaults cell_clip to 0.0 (disabled). PiperOrigin-RevId: 263492926
This commit is contained in:
parent
3f1deea3cb
commit
722b96b229
171
tensorflow/core/api_def/base_api/api_def_BlockLSTMGradV2.pbtxt
Normal file
171
tensorflow/core/api_def/base_api/api_def_BlockLSTMGradV2.pbtxt
Normal file
@ -0,0 +1,171 @@
|
||||
op {
|
||||
graph_op_name: "BlockLSTMGradV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "seq_len_max"
|
||||
description: <<END
|
||||
Maximum time length actually used by this input. Outputs are padded
|
||||
with zeros beyond this length.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "x"
|
||||
description: <<END
|
||||
The sequence input to the LSTM, shape (timelen, batch_size, num_inputs).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "cs_prev"
|
||||
description: <<END
|
||||
Value of the initial cell state.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "h_prev"
|
||||
description: <<END
|
||||
Initial output of cell (to be used for peephole).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "w"
|
||||
description: <<END
|
||||
The weight matrix.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "wci"
|
||||
description: <<END
|
||||
The weight matrix for input gate peephole connection.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "wcf"
|
||||
description: <<END
|
||||
The weight matrix for forget gate peephole connection.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "wco"
|
||||
description: <<END
|
||||
The weight matrix for output gate peephole connection.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "b"
|
||||
description: <<END
|
||||
The bias vector.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "i"
|
||||
description: <<END
|
||||
The input gate over the whole time sequence.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "cs"
|
||||
description: <<END
|
||||
The cell state before the tanh over the whole time sequence.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "f"
|
||||
description: <<END
|
||||
The forget gate over the whole time sequence.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "o"
|
||||
description: <<END
|
||||
The output gate over the whole time sequence.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "ci"
|
||||
description: <<END
|
||||
The cell input over the whole time sequence.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "co"
|
||||
description: <<END
|
||||
The cell after the tanh over the whole time sequence.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "h"
|
||||
description: <<END
|
||||
The output h vector over the whole time sequence.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "cs_grad"
|
||||
description: <<END
|
||||
The current gradient of cs.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "h_grad"
|
||||
description: <<END
|
||||
The gradient of h vector.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "x_grad"
|
||||
description: <<END
|
||||
The gradient of x to be back-propped.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "cs_prev_grad"
|
||||
description: <<END
|
||||
The gradient of cs_prev to be back-propped.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "h_prev_grad"
|
||||
description: <<END
|
||||
The gradient of h_prev to be back-propped.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "w_grad"
|
||||
description: <<END
|
||||
The gradient for w to be back-propped.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "wci_grad"
|
||||
description: <<END
|
||||
The gradient for wci to be back-propped.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "wcf_grad"
|
||||
description: <<END
|
||||
The gradient for wcf to be back-propped.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "wco_grad"
|
||||
description: <<END
|
||||
The gradient for wco to be back-propped.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "b_grad"
|
||||
description: <<END
|
||||
The gradient for w to be back-propped.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_peephole"
|
||||
description: <<END
|
||||
Whether to use peephole weights.
|
||||
END
|
||||
}
|
||||
summary: "Computes the LSTM cell backward propagation for the entire time sequence."
|
||||
description: <<END
|
||||
This implementation is to be used in conjunction of BlockLSTMV2.
|
||||
END
|
||||
}
|
137
tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt
Normal file
137
tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt
Normal file
@ -0,0 +1,137 @@
|
||||
op {
|
||||
graph_op_name: "BlockLSTMV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "seq_len_max"
|
||||
description: <<END
|
||||
Maximum time length actually used by this input. Outputs are padded
|
||||
with zeros beyond this length.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "x"
|
||||
description: <<END
|
||||
The sequence input to the LSTM, shape (timelen, batch_size, num_inputs).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "cs_prev"
|
||||
description: <<END
|
||||
Value of the initial cell state.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "h_prev"
|
||||
description: <<END
|
||||
Initial output of cell (to be used for peephole).
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "w"
|
||||
description: <<END
|
||||
The weight matrix.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "wci"
|
||||
description: <<END
|
||||
The weight matrix for input gate peephole connection.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "wcf"
|
||||
description: <<END
|
||||
The weight matrix for forget gate peephole connection.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "wco"
|
||||
description: <<END
|
||||
The weight matrix for output gate peephole connection.
|
||||
END
|
||||
}
|
||||
in_arg {
|
||||
name: "b"
|
||||
description: <<END
|
||||
The bias vector.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "i"
|
||||
description: <<END
|
||||
The input gate over the whole time sequence.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "cs"
|
||||
description: <<END
|
||||
The cell state before the tanh over the whole time sequence.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "f"
|
||||
description: <<END
|
||||
The forget gate over the whole time sequence.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "o"
|
||||
description: <<END
|
||||
The output gate over the whole time sequence.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "ci"
|
||||
description: <<END
|
||||
The cell input over the whole time sequence.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "co"
|
||||
description: <<END
|
||||
The cell after the tanh over the whole time sequence.
|
||||
END
|
||||
}
|
||||
out_arg {
|
||||
name: "h"
|
||||
description: <<END
|
||||
The output h vector over the whole time sequence.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "cell_clip"
|
||||
description: <<END
|
||||
Value to clip the 'cs' value to.
|
||||
END
|
||||
}
|
||||
attr {
|
||||
name: "use_peephole"
|
||||
description: <<END
|
||||
Whether to use peephole weights.
|
||||
END
|
||||
}
|
||||
summary: "Computes the LSTM cell forward propagation for all the time steps."
|
||||
description: <<END
|
||||
This is equivalent to applying LSTMBlockCell in a loop, like so:
|
||||
|
||||
```python
|
||||
for x1 in unpack(x):
|
||||
i1, cs1, f1, o1, ci1, co1, h1 = LSTMBlock(
|
||||
x1, cs_prev, h_prev, w, wci, wcf, wco, b)
|
||||
cs_prev = cs1
|
||||
h_prev = h1
|
||||
i.append(i1)
|
||||
cs.append(cs1)
|
||||
f.append(f1)
|
||||
o.append(o1)
|
||||
ci.append(ci1)
|
||||
co.append(co1)
|
||||
h.append(h1)
|
||||
return pack(i), pack(cs), pack(f), pack(o), pack(ci), pack(ch), pack(h)
|
||||
|
||||
Note that unlike LSTMBlockCell (and BlockLSTM) which uses ICFO gate layout,
|
||||
this op uses IFCO. So in order for the following snippet to be equivalent
|
||||
all gate-related outputs should be reordered.
|
||||
```
|
||||
END
|
||||
}
|
@ -235,7 +235,9 @@ void LSTMBlockCellBpropWithEigen(
|
||||
template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */, \
|
||||
GATE_LAYOUT>;
|
||||
|
||||
#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T, ICFO);
|
||||
#define DECLARE_CPU_SPECS(T) \
|
||||
DECLARE_CPU_FBPROP(T, ICFO); \
|
||||
DECLARE_CPU_FBPROP(T, IFCO);
|
||||
|
||||
DECLARE_CPU_SPECS(Eigen::half);
|
||||
DECLARE_CPU_SPECS(float);
|
||||
@ -827,7 +829,12 @@ template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout>
|
||||
class BlockLSTMOp : public OpKernel {
|
||||
public:
|
||||
explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
|
||||
if (ctx->HasAttr("forget_bias")) {
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_));
|
||||
} else {
|
||||
// V2 version does not have "forget_bias" attribute.
|
||||
forget_bias_ = 0.0;
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_));
|
||||
}
|
||||
@ -1010,10 +1017,13 @@ class BlockLSTMOp : public OpKernel {
|
||||
bool use_peephole_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
BlockLSTMOp<CPUDevice, T, false, ICFO>);
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
BlockLSTMOp<CPUDevice, T, false, ICFO>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BlockLSTMV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
BlockLSTMOp<CPUDevice, T, false, IFCO>);
|
||||
|
||||
REGISTER_KERNEL(Eigen::half);
|
||||
REGISTER_KERNEL(float);
|
||||
@ -1039,12 +1049,17 @@ DECLARE_GPU_SPECS(float);
|
||||
#undef DECLARE_GPU_SPECS
|
||||
} // end namespace functor
|
||||
|
||||
#define REGISTER_GPU_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("BlockLSTM") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("seq_len_max") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
BlockLSTMOp<GPUDevice, T, true, ICFO>);
|
||||
#define REGISTER_GPU_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("BlockLSTM") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("seq_len_max") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
BlockLSTMOp<GPUDevice, T, true, ICFO>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("BlockLSTMV2") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("seq_len_max") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
BlockLSTMOp<GPUDevice, T, true, IFCO>);
|
||||
|
||||
REGISTER_GPU_KERNEL(Eigen::half);
|
||||
REGISTER_GPU_KERNEL(float);
|
||||
@ -1284,10 +1299,13 @@ class BlockLSTMGradOp : public OpKernel {
|
||||
bool use_peephole_;
|
||||
};
|
||||
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
BlockLSTMGradOp<CPUDevice, T, false, ICFO>);
|
||||
#define REGISTER_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
BlockLSTMGradOp<CPUDevice, T, false, ICFO>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("BlockLSTMGradV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
BlockLSTMGradOp<CPUDevice, T, false, IFCO>);
|
||||
|
||||
REGISTER_KERNEL(Eigen::half);
|
||||
REGISTER_KERNEL(float);
|
||||
@ -1345,7 +1363,8 @@ namespace functor {
|
||||
extern template struct TensorCopy<GPUDevice, T>; \
|
||||
extern template struct TensorAdd<GPUDevice, T>; \
|
||||
\
|
||||
DECLARE_GPU_BPROP(T, ICFO);
|
||||
DECLARE_GPU_BPROP(T, ICFO); \
|
||||
DECLARE_GPU_BPROP(T, IFCO);
|
||||
|
||||
DECLARE_GPU_SPECS(Eigen::half);
|
||||
DECLARE_GPU_SPECS(float);
|
||||
@ -1353,12 +1372,17 @@ DECLARE_GPU_SPECS(float);
|
||||
#undef DECLARE_GPU_BPROP
|
||||
} // end namespace functor
|
||||
|
||||
#define REGISTER_GPU_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("seq_len_max") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
BlockLSTMGradOp<GPUDevice, T, true, ICFO>);
|
||||
#define REGISTER_GPU_KERNEL(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("seq_len_max") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
BlockLSTMGradOp<GPUDevice, T, true, ICFO>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("BlockLSTMGradV2") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.HostMemory("seq_len_max") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
BlockLSTMGradOp<GPUDevice, T, true, IFCO>);
|
||||
|
||||
REGISTER_GPU_KERNEL(Eigen::half);
|
||||
REGISTER_GPU_KERNEL(float);
|
||||
|
@ -468,7 +468,8 @@ void LSTMBlockCellBpropWithCUDA(
|
||||
template struct TensorCopyToUnaligned<GPUDevice, T>; \
|
||||
template struct TensorAdd<GPUDevice, T>; \
|
||||
\
|
||||
DECLARE_GPU_FBPROP(T, ICFO);
|
||||
DECLARE_GPU_FBPROP(T, ICFO); \
|
||||
DECLARE_GPU_FBPROP(T, IFCO);
|
||||
|
||||
DECLARE_GPU_SPECS(Eigen::half);
|
||||
DECLARE_GPU_SPECS(float);
|
||||
|
@ -199,6 +199,45 @@ REGISTER_OP("BlockLSTM")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("BlockLSTMV2")
|
||||
.Input("seq_len_max: int64")
|
||||
.Input("x: T")
|
||||
.Input("cs_prev: T")
|
||||
.Input("h_prev: T")
|
||||
.Input("w: T")
|
||||
.Input("wci: T")
|
||||
.Input("wcf: T")
|
||||
.Input("wco: T")
|
||||
.Input("b: T")
|
||||
.Output("i: T")
|
||||
.Output("cs: T")
|
||||
.Output("f: T")
|
||||
.Output("o: T")
|
||||
.Output("ci: T")
|
||||
.Output("co: T")
|
||||
.Output("h: T")
|
||||
.Attr("cell_clip: float = 0.0")
|
||||
.Attr("use_peephole: bool = false")
|
||||
.Attr("T: {half, float}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle x, b;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b));
|
||||
|
||||
DimensionHandle timelen = c->Dim(x, 0);
|
||||
DimensionHandle batch_size = c->Dim(x, 1);
|
||||
DimensionHandle cell_size;
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size));
|
||||
|
||||
DCHECK_EQ(7, c->num_outputs());
|
||||
ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size});
|
||||
for (int i = 0; i < 7; ++i) {
|
||||
c->set_output(i, output);
|
||||
}
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("BlockLSTMGrad")
|
||||
.Input("seq_len_max: int64")
|
||||
.Input("x: T")
|
||||
@ -251,4 +290,56 @@ REGISTER_OP("BlockLSTMGrad")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
REGISTER_OP("BlockLSTMGradV2")
|
||||
.Input("seq_len_max: int64")
|
||||
.Input("x: T")
|
||||
.Input("cs_prev: T")
|
||||
.Input("h_prev: T")
|
||||
.Input("w: T")
|
||||
.Input("wci: T")
|
||||
.Input("wcf: T")
|
||||
.Input("wco: T")
|
||||
.Input("b: T")
|
||||
.Input("i: T")
|
||||
.Input("cs: T")
|
||||
.Input("f: T")
|
||||
.Input("o: T")
|
||||
.Input("ci: T")
|
||||
.Input("co: T")
|
||||
.Input("h: T")
|
||||
.Input("cs_grad: T")
|
||||
.Input("h_grad: T")
|
||||
.Output("x_grad: T")
|
||||
.Output("cs_prev_grad: T")
|
||||
.Output("h_prev_grad: T")
|
||||
.Output("w_grad: T")
|
||||
.Output("wci_grad: T")
|
||||
.Output("wcf_grad: T")
|
||||
.Output("wco_grad: T")
|
||||
.Output("b_grad: T")
|
||||
.Attr("use_peephole: bool")
|
||||
.Attr("T: {half, float}")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b));
|
||||
|
||||
c->set_output(0, x);
|
||||
c->set_output(1, cs_prev);
|
||||
c->set_output(2, h_prev);
|
||||
c->set_output(3, w);
|
||||
c->set_output(4, wci);
|
||||
c->set_output(5, wco);
|
||||
c->set_output(6, wcf);
|
||||
c->set_output(7, b);
|
||||
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -3691,6 +3691,24 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "rnn_grad_test",
|
||||
srcs = ["ops/rnn_grad_test.py"],
|
||||
python_version = "PY3",
|
||||
deps = [
|
||||
":array_ops",
|
||||
":client_testlib",
|
||||
":dtypes",
|
||||
":framework_ops",
|
||||
":framework_test_lib",
|
||||
":gradients",
|
||||
":math_ops",
|
||||
":rnn_grad",
|
||||
":rnn_ops_gen",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "standard_ops",
|
||||
srcs = ["ops/standard_ops.py"],
|
||||
|
@ -21,7 +21,6 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import gen_rnn_ops
|
||||
|
||||
|
||||
@ops.RegisterGradient("BlockLSTM")
|
||||
def _block_lstm_grad(op, *grads):
|
||||
"""Gradient for the BlockLSTM op."""
|
||||
seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs
|
||||
@ -50,3 +49,7 @@ def _block_lstm_grad(op, *grads):
|
||||
use_peephole=op.get_attr("use_peephole"))
|
||||
return (None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad,
|
||||
wco_grad, b_grad)
|
||||
|
||||
|
||||
ops.RegisterGradient("BlockLSTM")(_block_lstm_grad)
|
||||
ops.RegisterGradient("BlockLSTMV2")(_block_lstm_grad)
|
||||
|
99
tensorflow/python/ops/rnn_grad_test.py
Normal file
99
tensorflow/python/ops/rnn_grad_test.py
Normal file
@ -0,0 +1,99 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for gradients of (block) LSTM/GRU operations."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_rnn_ops
|
||||
from tensorflow.python.ops import gradients
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import rnn_grad # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class RNNGradTest(test.TestCase):
|
||||
|
||||
@test_util.deprecated_graph_mode_only
|
||||
def testBlockLSTMV1V2Consistency(self):
|
||||
num_steps = 1
|
||||
batch_size = 1
|
||||
input_size = 1
|
||||
hidden_size = 8
|
||||
w = deterministic_random_uniform(
|
||||
[input_size + hidden_size, 4 * hidden_size])
|
||||
b = deterministic_random_uniform([4 * hidden_size])
|
||||
x = deterministic_random_uniform([num_steps, batch_size, input_size])
|
||||
cs_prev = h_prev = deterministic_random_uniform([batch_size, hidden_size])
|
||||
|
||||
all_cs, all_h = self._lstm_block(
|
||||
functools.partial(
|
||||
gen_rnn_ops.BlockLSTM,
|
||||
forget_bias=0.0, # Disable to match V2 default.
|
||||
cell_clip=0.0), # Disable to match V2 default.
|
||||
w, b, x, cs_prev, h_prev)
|
||||
w_grad, b_grad = gradients.gradients(all_cs + all_h, [w, b])
|
||||
|
||||
w_ifco, b_ifco = icfo_to_ifco(w, b)
|
||||
all_cs_ifco, all_h_ifco = self._lstm_block(
|
||||
gen_rnn_ops.BlockLSTMV2, w_ifco, b_ifco, x, cs_prev, h_prev)
|
||||
w_ifco_grad, b_ifco_grad = gradients.gradients(
|
||||
all_cs_ifco + all_h_ifco, [w_ifco, b_ifco])
|
||||
|
||||
self.assertAllEqual(all_cs, all_cs_ifco)
|
||||
self.assertAllEqual(all_h, all_h_ifco)
|
||||
self.assertAllEqual(w_grad, w_ifco_grad)
|
||||
self.assertAllEqual(b_grad, b_ifco_grad)
|
||||
|
||||
def _lstm_block(self, op, w, b, x, cs_prev, h_prev):
|
||||
w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype)
|
||||
_, all_cs, _, _, _, _, all_h = op(
|
||||
seq_len_max=math_ops.cast(array_ops.shape(x)[0], dtypes.int64),
|
||||
x=x,
|
||||
cs_prev=cs_prev,
|
||||
h_prev=h_prev,
|
||||
w=w,
|
||||
wci=w_peephole,
|
||||
wcf=w_peephole,
|
||||
wco=w_peephole,
|
||||
b=b,
|
||||
use_peephole=False)
|
||||
return all_cs, all_h
|
||||
|
||||
|
||||
def deterministic_random_uniform(shape):
|
||||
return ops.convert_to_tensor(np.random.random(shape), dtype=dtypes.float32)
|
||||
|
||||
|
||||
def icfo_to_ifco(w, b):
|
||||
"""Convert gates' weights and biases from ICFO to IFCO layout."""
|
||||
w_i, w_c, w_f, w_o = array_ops.split(w, num_or_size_splits=4, axis=1)
|
||||
b_i, b_c, b_f, b_o = array_ops.split(b, num_or_size_splits=4)
|
||||
w_ifco = array_ops.concat([w_i, w_f, w_c, w_o], axis=1)
|
||||
b_ifco = array_ops.concat([b_i, b_f, b_c, b_o], axis=0)
|
||||
return w_ifco, b_ifco
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -472,6 +472,14 @@ tf_module {
|
||||
name: "BlockLSTMGrad"
|
||||
argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BlockLSTMGradV2"
|
||||
argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BlockLSTMV2"
|
||||
argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'cell_clip\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BoostedTreesAggregateStats"
|
||||
argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -472,6 +472,14 @@ tf_module {
|
||||
name: "BlockLSTMGrad"
|
||||
argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BlockLSTMGradV2"
|
||||
argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BlockLSTMV2"
|
||||
argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'cell_clip\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "BoostedTreesAggregateStats"
|
||||
argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user