diff --git a/tensorflow/core/api_def/base_api/api_def_BlockLSTMGradV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BlockLSTMGradV2.pbtxt new file mode 100644 index 00000000000..638d1549804 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BlockLSTMGradV2.pbtxt @@ -0,0 +1,171 @@ +op { + graph_op_name: "BlockLSTMGradV2" + visibility: HIDDEN + in_arg { + name: "seq_len_max" + description: <<END +Maximum time length actually used by this input. Outputs are padded +with zeros beyond this length. +END + } + in_arg { + name: "x" + description: <<END +The sequence input to the LSTM, shape (timelen, batch_size, num_inputs). +END + } + in_arg { + name: "cs_prev" + description: <<END +Value of the initial cell state. +END + } + in_arg { + name: "h_prev" + description: <<END +Initial output of cell (to be used for peephole). +END + } + in_arg { + name: "w" + description: <<END +The weight matrix. +END + } + in_arg { + name: "wci" + description: <<END +The weight matrix for input gate peephole connection. +END + } + in_arg { + name: "wcf" + description: <<END +The weight matrix for forget gate peephole connection. +END + } + in_arg { + name: "wco" + description: <<END +The weight matrix for output gate peephole connection. +END + } + in_arg { + name: "b" + description: <<END +The bias vector. +END + } + in_arg { + name: "i" + description: <<END +The input gate over the whole time sequence. +END + } + in_arg { + name: "cs" + description: <<END +The cell state before the tanh over the whole time sequence. +END + } + in_arg { + name: "f" + description: <<END +The forget gate over the whole time sequence. +END + } + in_arg { + name: "o" + description: <<END +The output gate over the whole time sequence. +END + } + in_arg { + name: "ci" + description: <<END +The cell input over the whole time sequence. +END + } + in_arg { + name: "co" + description: <<END +The cell after the tanh over the whole time sequence. +END + } + in_arg { + name: "h" + description: <<END +The output h vector over the whole time sequence. +END + } + in_arg { + name: "cs_grad" + description: <<END +The current gradient of cs. +END + } + in_arg { + name: "h_grad" + description: <<END +The gradient of h vector. +END + } + out_arg { + name: "x_grad" + description: <<END +The gradient of x to be back-propped. +END + } + out_arg { + name: "cs_prev_grad" + description: <<END +The gradient of cs_prev to be back-propped. +END + } + out_arg { + name: "h_prev_grad" + description: <<END +The gradient of h_prev to be back-propped. +END + } + out_arg { + name: "w_grad" + description: <<END +The gradient for w to be back-propped. +END + } + out_arg { + name: "wci_grad" + description: <<END +The gradient for wci to be back-propped. +END + } + out_arg { + name: "wcf_grad" + description: <<END +The gradient for wcf to be back-propped. +END + } + out_arg { + name: "wco_grad" + description: <<END +The gradient for wco to be back-propped. +END + } + out_arg { + name: "b_grad" + description: <<END +The gradient for w to be back-propped. +END + } + attr { + name: "use_peephole" + description: <<END +Whether to use peephole weights. +END + } + summary: "Computes the LSTM cell backward propagation for the entire time sequence." + description: <<END +This implementation is to be used in conjunction of BlockLSTMV2. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt new file mode 100644 index 00000000000..4da9ebaf863 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BlockLSTMV2.pbtxt @@ -0,0 +1,137 @@ +op { + graph_op_name: "BlockLSTMV2" + visibility: HIDDEN + in_arg { + name: "seq_len_max" + description: <<END +Maximum time length actually used by this input. Outputs are padded +with zeros beyond this length. +END + } + in_arg { + name: "x" + description: <<END +The sequence input to the LSTM, shape (timelen, batch_size, num_inputs). +END + } + in_arg { + name: "cs_prev" + description: <<END +Value of the initial cell state. +END + } + in_arg { + name: "h_prev" + description: <<END +Initial output of cell (to be used for peephole). +END + } + in_arg { + name: "w" + description: <<END +The weight matrix. +END + } + in_arg { + name: "wci" + description: <<END +The weight matrix for input gate peephole connection. +END + } + in_arg { + name: "wcf" + description: <<END +The weight matrix for forget gate peephole connection. +END + } + in_arg { + name: "wco" + description: <<END +The weight matrix for output gate peephole connection. +END + } + in_arg { + name: "b" + description: <<END +The bias vector. +END + } + out_arg { + name: "i" + description: <<END +The input gate over the whole time sequence. +END + } + out_arg { + name: "cs" + description: <<END +The cell state before the tanh over the whole time sequence. +END + } + out_arg { + name: "f" + description: <<END +The forget gate over the whole time sequence. +END + } + out_arg { + name: "o" + description: <<END +The output gate over the whole time sequence. +END + } + out_arg { + name: "ci" + description: <<END +The cell input over the whole time sequence. +END + } + out_arg { + name: "co" + description: <<END +The cell after the tanh over the whole time sequence. +END + } + out_arg { + name: "h" + description: <<END +The output h vector over the whole time sequence. +END + } + attr { + name: "cell_clip" + description: <<END +Value to clip the 'cs' value to. +END + } + attr { + name: "use_peephole" + description: <<END +Whether to use peephole weights. +END + } + summary: "Computes the LSTM cell forward propagation for all the time steps." + description: <<END +This is equivalent to applying LSTMBlockCell in a loop, like so: + +```python +for x1 in unpack(x): + i1, cs1, f1, o1, ci1, co1, h1 = LSTMBlock( + x1, cs_prev, h_prev, w, wci, wcf, wco, b) + cs_prev = cs1 + h_prev = h1 + i.append(i1) + cs.append(cs1) + f.append(f1) + o.append(o1) + ci.append(ci1) + co.append(co1) + h.append(h1) +return pack(i), pack(cs), pack(f), pack(o), pack(ci), pack(ch), pack(h) + +Note that unlike LSTMBlockCell (and BlockLSTM) which uses ICFO gate layout, +this op uses IFCO. So in order for the following snippet to be equivalent +all gate-related outputs should be reordered. +``` +END +} diff --git a/tensorflow/core/kernels/rnn/lstm_ops.cc b/tensorflow/core/kernels/rnn/lstm_ops.cc index 57d3e9b1323..dfd77f9ca5f 100644 --- a/tensorflow/core/kernels/rnn/lstm_ops.cc +++ b/tensorflow/core/kernels/rnn/lstm_ops.cc @@ -235,7 +235,9 @@ void LSTMBlockCellBpropWithEigen( template struct LSTMBlockCellBprop<CPUDevice, T, false /* USE_CUBLAS */, \ GATE_LAYOUT>; -#define DECLARE_CPU_SPECS(T) DECLARE_CPU_FBPROP(T, ICFO); +#define DECLARE_CPU_SPECS(T) \ + DECLARE_CPU_FBPROP(T, ICFO); \ + DECLARE_CPU_FBPROP(T, IFCO); DECLARE_CPU_SPECS(Eigen::half); DECLARE_CPU_SPECS(float); @@ -827,7 +829,12 @@ template <typename Device, typename T, bool USE_CUBLAS, GateLayout gate_layout> class BlockLSTMOp : public OpKernel { public: explicit BlockLSTMOp(OpKernelConstruction* ctx) : OpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_)); + if (ctx->HasAttr("forget_bias")) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("forget_bias", &forget_bias_)); + } else { + // V2 version does not have "forget_bias" attribute. + forget_bias_ = 0.0; + } OP_REQUIRES_OK(ctx, ctx->GetAttr("cell_clip", &cell_clip_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("use_peephole", &use_peephole_)); } @@ -1010,10 +1017,13 @@ class BlockLSTMOp : public OpKernel { bool use_peephole_; }; -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - BlockLSTMOp<CPUDevice, T, false, ICFO>); +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("BlockLSTM").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + BlockLSTMOp<CPUDevice, T, false, ICFO>); \ + REGISTER_KERNEL_BUILDER( \ + Name("BlockLSTMV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + BlockLSTMOp<CPUDevice, T, false, IFCO>); REGISTER_KERNEL(Eigen::half); REGISTER_KERNEL(float); @@ -1039,12 +1049,17 @@ DECLARE_GPU_SPECS(float); #undef DECLARE_GPU_SPECS } // end namespace functor -#define REGISTER_GPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("BlockLSTM") \ - .Device(DEVICE_GPU) \ - .HostMemory("seq_len_max") \ - .TypeConstraint<T>("T"), \ - BlockLSTMOp<GPUDevice, T, true, ICFO>); +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BlockLSTM") \ + .Device(DEVICE_GPU) \ + .HostMemory("seq_len_max") \ + .TypeConstraint<T>("T"), \ + BlockLSTMOp<GPUDevice, T, true, ICFO>); \ + REGISTER_KERNEL_BUILDER(Name("BlockLSTMV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("seq_len_max") \ + .TypeConstraint<T>("T"), \ + BlockLSTMOp<GPUDevice, T, true, IFCO>); REGISTER_GPU_KERNEL(Eigen::half); REGISTER_GPU_KERNEL(float); @@ -1284,10 +1299,13 @@ class BlockLSTMGradOp : public OpKernel { bool use_peephole_; }; -#define REGISTER_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - BlockLSTMGradOp<CPUDevice, T, false, ICFO>); +#define REGISTER_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("BlockLSTMGrad").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + BlockLSTMGradOp<CPUDevice, T, false, ICFO>); \ + REGISTER_KERNEL_BUILDER( \ + Name("BlockLSTMGradV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + BlockLSTMGradOp<CPUDevice, T, false, IFCO>); REGISTER_KERNEL(Eigen::half); REGISTER_KERNEL(float); @@ -1345,7 +1363,8 @@ namespace functor { extern template struct TensorCopy<GPUDevice, T>; \ extern template struct TensorAdd<GPUDevice, T>; \ \ - DECLARE_GPU_BPROP(T, ICFO); + DECLARE_GPU_BPROP(T, ICFO); \ + DECLARE_GPU_BPROP(T, IFCO); DECLARE_GPU_SPECS(Eigen::half); DECLARE_GPU_SPECS(float); @@ -1353,12 +1372,17 @@ DECLARE_GPU_SPECS(float); #undef DECLARE_GPU_BPROP } // end namespace functor -#define REGISTER_GPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad") \ - .Device(DEVICE_GPU) \ - .HostMemory("seq_len_max") \ - .TypeConstraint<T>("T"), \ - BlockLSTMGradOp<GPUDevice, T, true, ICFO>); +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("BlockLSTMGrad") \ + .Device(DEVICE_GPU) \ + .HostMemory("seq_len_max") \ + .TypeConstraint<T>("T"), \ + BlockLSTMGradOp<GPUDevice, T, true, ICFO>); \ + REGISTER_KERNEL_BUILDER(Name("BlockLSTMGradV2") \ + .Device(DEVICE_GPU) \ + .HostMemory("seq_len_max") \ + .TypeConstraint<T>("T"), \ + BlockLSTMGradOp<GPUDevice, T, true, IFCO>); REGISTER_GPU_KERNEL(Eigen::half); REGISTER_GPU_KERNEL(float); diff --git a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc index f3f37858986..3c1ea27b1ea 100644 --- a/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/rnn/lstm_ops_gpu.cu.cc @@ -468,7 +468,8 @@ void LSTMBlockCellBpropWithCUDA( template struct TensorCopyToUnaligned<GPUDevice, T>; \ template struct TensorAdd<GPUDevice, T>; \ \ - DECLARE_GPU_FBPROP(T, ICFO); + DECLARE_GPU_FBPROP(T, ICFO); \ + DECLARE_GPU_FBPROP(T, IFCO); DECLARE_GPU_SPECS(Eigen::half); DECLARE_GPU_SPECS(float); diff --git a/tensorflow/core/ops/rnn_ops.cc b/tensorflow/core/ops/rnn_ops.cc index b926feb9d2e..af5dc3d26d1 100644 --- a/tensorflow/core/ops/rnn_ops.cc +++ b/tensorflow/core/ops/rnn_ops.cc @@ -199,6 +199,45 @@ REGISTER_OP("BlockLSTM") return Status::OK(); }); +REGISTER_OP("BlockLSTMV2") + .Input("seq_len_max: int64") + .Input("x: T") + .Input("cs_prev: T") + .Input("h_prev: T") + .Input("w: T") + .Input("wci: T") + .Input("wcf: T") + .Input("wco: T") + .Input("b: T") + .Output("i: T") + .Output("cs: T") + .Output("f: T") + .Output("o: T") + .Output("ci: T") + .Output("co: T") + .Output("h: T") + .Attr("cell_clip: float = 0.0") + .Attr("use_peephole: bool = false") + .Attr("T: {half, float}") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle x, b; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &b)); + + DimensionHandle timelen = c->Dim(x, 0); + DimensionHandle batch_size = c->Dim(x, 1); + DimensionHandle cell_size; + TF_RETURN_IF_ERROR( + c->Divide(c->Dim(b, 0), 4, true /* evenly_divisible */, &cell_size)); + + DCHECK_EQ(7, c->num_outputs()); + ShapeHandle output = c->MakeShape({timelen, batch_size, cell_size}); + for (int i = 0; i < 7; ++i) { + c->set_output(i, output); + } + return Status::OK(); + }); + REGISTER_OP("BlockLSTMGrad") .Input("seq_len_max: int64") .Input("x: T") @@ -251,4 +290,56 @@ REGISTER_OP("BlockLSTMGrad") return Status::OK(); }); +REGISTER_OP("BlockLSTMGradV2") + .Input("seq_len_max: int64") + .Input("x: T") + .Input("cs_prev: T") + .Input("h_prev: T") + .Input("w: T") + .Input("wci: T") + .Input("wcf: T") + .Input("wco: T") + .Input("b: T") + .Input("i: T") + .Input("cs: T") + .Input("f: T") + .Input("o: T") + .Input("ci: T") + .Input("co: T") + .Input("h: T") + .Input("cs_grad: T") + .Input("h_grad: T") + .Output("x_grad: T") + .Output("cs_prev_grad: T") + .Output("h_prev_grad: T") + .Output("w_grad: T") + .Output("wci_grad: T") + .Output("wcf_grad: T") + .Output("wco_grad: T") + .Output("b_grad: T") + .Attr("use_peephole: bool") + .Attr("T: {half, float}") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle x, cs_prev, h_prev, w, wci, wco, wcf, b; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &x)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &cs_prev)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &h_prev)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 2, &w)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &wci)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 1, &wco)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &wcf)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 1, &b)); + + c->set_output(0, x); + c->set_output(1, cs_prev); + c->set_output(2, h_prev); + c->set_output(3, w); + c->set_output(4, wci); + c->set_output(5, wco); + c->set_output(6, wcf); + c->set_output(7, b); + + return Status::OK(); + }); + } // end namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 996c31b8635..77a21caaf63 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3691,6 +3691,24 @@ py_library( ], ) +py_test( + name = "rnn_grad_test", + srcs = ["ops/rnn_grad_test.py"], + python_version = "PY3", + deps = [ + ":array_ops", + ":client_testlib", + ":dtypes", + ":framework_ops", + ":framework_test_lib", + ":gradients", + ":math_ops", + ":rnn_grad", + ":rnn_ops_gen", + "//third_party/py/numpy", + ], +) + py_library( name = "standard_ops", srcs = ["ops/standard_ops.py"], diff --git a/tensorflow/python/ops/rnn_grad.py b/tensorflow/python/ops/rnn_grad.py index f2707e178b0..e316b7fb8a1 100644 --- a/tensorflow/python/ops/rnn_grad.py +++ b/tensorflow/python/ops/rnn_grad.py @@ -21,7 +21,6 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import gen_rnn_ops -@ops.RegisterGradient("BlockLSTM") def _block_lstm_grad(op, *grads): """Gradient for the BlockLSTM op.""" seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs @@ -50,3 +49,7 @@ def _block_lstm_grad(op, *grads): use_peephole=op.get_attr("use_peephole")) return (None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad, b_grad) + + +ops.RegisterGradient("BlockLSTM")(_block_lstm_grad) +ops.RegisterGradient("BlockLSTMV2")(_block_lstm_grad) diff --git a/tensorflow/python/ops/rnn_grad_test.py b/tensorflow/python/ops/rnn_grad_test.py new file mode 100644 index 00000000000..2b320234538 --- /dev/null +++ b/tensorflow/python/ops/rnn_grad_test.py @@ -0,0 +1,99 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for gradients of (block) LSTM/GRU operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_rnn_ops +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import rnn_grad # pylint: disable=unused-import +from tensorflow.python.platform import test + + +class RNNGradTest(test.TestCase): + + @test_util.deprecated_graph_mode_only + def testBlockLSTMV1V2Consistency(self): + num_steps = 1 + batch_size = 1 + input_size = 1 + hidden_size = 8 + w = deterministic_random_uniform( + [input_size + hidden_size, 4 * hidden_size]) + b = deterministic_random_uniform([4 * hidden_size]) + x = deterministic_random_uniform([num_steps, batch_size, input_size]) + cs_prev = h_prev = deterministic_random_uniform([batch_size, hidden_size]) + + all_cs, all_h = self._lstm_block( + functools.partial( + gen_rnn_ops.BlockLSTM, + forget_bias=0.0, # Disable to match V2 default. + cell_clip=0.0), # Disable to match V2 default. + w, b, x, cs_prev, h_prev) + w_grad, b_grad = gradients.gradients(all_cs + all_h, [w, b]) + + w_ifco, b_ifco = icfo_to_ifco(w, b) + all_cs_ifco, all_h_ifco = self._lstm_block( + gen_rnn_ops.BlockLSTMV2, w_ifco, b_ifco, x, cs_prev, h_prev) + w_ifco_grad, b_ifco_grad = gradients.gradients( + all_cs_ifco + all_h_ifco, [w_ifco, b_ifco]) + + self.assertAllEqual(all_cs, all_cs_ifco) + self.assertAllEqual(all_h, all_h_ifco) + self.assertAllEqual(w_grad, w_ifco_grad) + self.assertAllEqual(b_grad, b_ifco_grad) + + def _lstm_block(self, op, w, b, x, cs_prev, h_prev): + w_peephole = array_ops.zeros(cs_prev.shape[1:], dtype=w.dtype) + _, all_cs, _, _, _, _, all_h = op( + seq_len_max=math_ops.cast(array_ops.shape(x)[0], dtypes.int64), + x=x, + cs_prev=cs_prev, + h_prev=h_prev, + w=w, + wci=w_peephole, + wcf=w_peephole, + wco=w_peephole, + b=b, + use_peephole=False) + return all_cs, all_h + + +def deterministic_random_uniform(shape): + return ops.convert_to_tensor(np.random.random(shape), dtype=dtypes.float32) + + +def icfo_to_ifco(w, b): + """Convert gates' weights and biases from ICFO to IFCO layout.""" + w_i, w_c, w_f, w_o = array_ops.split(w, num_or_size_splits=4, axis=1) + b_i, b_c, b_f, b_o = array_ops.split(b, num_or_size_splits=4) + w_ifco = array_ops.concat([w_i, w_f, w_c, w_o], axis=1) + b_ifco = array_ops.concat([b_i, b_f, b_c, b_o], axis=0) + return w_ifco, b_ifco + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 0f9e320753f..1136881b8a5 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -472,6 +472,14 @@ tf_module { name: "BlockLSTMGrad" argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "BlockLSTMGradV2" + argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "BlockLSTMV2" + argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'cell_clip\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'None\'], " + } member_method { name: "BoostedTreesAggregateStats" argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 0f9e320753f..1136881b8a5 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -472,6 +472,14 @@ tf_module { name: "BlockLSTMGrad" argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "BlockLSTMGradV2" + argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'i\', \'cs\', \'f\', \'o\', \'ci\', \'co\', \'h\', \'cs_grad\', \'h_grad\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "BlockLSTMV2" + argspec: "args=[\'seq_len_max\', \'x\', \'cs_prev\', \'h_prev\', \'w\', \'wci\', \'wcf\', \'wco\', \'b\', \'cell_clip\', \'use_peephole\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'None\'], " + } member_method { name: "BoostedTreesAggregateStats" argspec: "args=[\'node_ids\', \'gradients\', \'hessians\', \'feature\', \'max_splits\', \'num_buckets\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "