From aeae4825b35eb8cb18b5478a1d8863272b8a3634 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 25 Feb 2016 15:02:54 -0800
Subject: [PATCH] Add symbolic gradient functions for Conv2D and MaxPool

Change: 115608522
---
 tensorflow/core/ops/nn_grad.cc | 66 ++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/tensorflow/core/ops/nn_grad.cc b/tensorflow/core/ops/nn_grad.cc
index b12b7469a11..2565fc07194 100644
--- a/tensorflow/core/ops/nn_grad.cc
+++ b/tensorflow/core/ops/nn_grad.cc
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
 
 namespace tensorflow {
 
@@ -94,4 +96,68 @@ Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) {
 }
 REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad);
 
+Status Conv2DGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  *g = FDH::Define(
+      // Arg defs
+      {"input: T", "filter: T", "grad: T"},
+      // Ret val defs
+      {"input_grad: T", "filter_grad: T"},
+      // Attr defs
+      {"T: {float, double}",
+       "strides: list(int)",
+       "use_cudnn_on_gpu: bool = true",
+       GetPaddingAttrString(),
+       GetConvnetDataFormatAttrString()},
+      // Nodes
+      {
+        {{"i_shape"}, "Shape", {"input"}, {{"T", "$T"}}},
+        {{"input_grad"}, "Conv2DBackpropInput", {"i_shape", "filter", "grad"},
+         /*Attrs=*/{{"T", "$T"},
+                    {"strides", "$strides"},
+                    {"padding", "$padding"},
+                    {"data_format", "$data_format"},
+                    {"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},
+
+        {{"f_shape"}, "Shape", {"filter"}, {{"T", "$T"}}},
+        {{"filter_grad"}, "Conv2DBackpropFilter", {"input", "f_shape", "grad"},
+         /*Attrs=*/{{"T", "$T"},
+                    {"strides", "$strides"},
+                    {"padding", "$padding"},
+                    {"data_format", "$data_format"},
+                    {"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},
+      });
+  // clang-format on
+  return Status::OK();
+}
+REGISTER_OP_GRADIENT("Conv2D", Conv2DGrad);
+
+Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
+  // clang-format off
+  *g = FDH::Define(
+      // Arg defs
+      {"input: float", "grad: float"},
+      // Ret val defs
+      {"output: float"},
+      // Attr defs
+      {"ksize: list(int) >= 4",
+       "strides: list(int) >= 4",
+       GetPaddingAttrString()},
+      // Nodes
+      {
+        // Invoke MaxPool again to recompute the outputs (removed by CSE?).
+        {{"maxpool"}, "MaxPool", {"input"},
+         /*Attrs=*/{{"ksize", "$ksize"},
+                    {"strides", "$strides"},
+                    {"padding", "$padding"}}},
+        {{"output"}, "MaxPoolGrad", {"input", "maxpool", "grad"},
+         /*Attrs=*/{{"ksize", "$ksize"},
+                    {"strides", "$strides"},
+                    {"padding", "$padding"}}}
+      });
+  // clang-format on
+  return Status::OK();
+}
+REGISTER_OP_GRADIENT("MaxPool", MaxPoolGrad);
+
 }  // end namespace tensorflow
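
For reference, a creator registered with REGISTER_OP_GRADIENT can be fetched
back through gradient::GetOpGradientCreator(), declared in the same
tensorflow/core/framework/function.h this file includes, and invoked to
materialize the gradient FunctionDef. A minimal sketch, not part of the
patch itself; the empty attr map is an assumption that works here only
because Conv2DGrad forwards its attrs symbolically ("$T", "$strides", ...)
rather than reading them:

#include <iostream>

#include "tensorflow/core/framework/function.h"

int main() {
  // Look up the creator that REGISTER_OP_GRADIENT("Conv2D", ...) installed.
  tensorflow::gradient::Creator creator;
  tensorflow::Status s =
      tensorflow::gradient::GetOpGradientCreator("Conv2D", &creator);
  if (!s.ok()) {
    std::cerr << s.ToString() << std::endl;
    return 1;
  }
  // Instantiate the gradient function; attrs would normally come from the
  // Conv2D node whose gradient is being built (empty here, see note above).
  tensorflow::AttrValueMap attrs;
  tensorflow::FunctionDef grad;
  s = creator(tensorflow::AttrSlice(&attrs), &grad);
  if (!s.ok()) {
    std::cerr << s.ToString() << std::endl;
    return 1;
  }
  // Prints a function containing the Shape, Conv2DBackpropInput, and
  // Conv2DBackpropFilter nodes defined in the patch above.
  std::cout << grad.DebugString() << std::endl;
  return 0;
}

The same lookup with "MaxPool" should return the MaxPoolGrad creator, whose
function body recomputes the forward MaxPool before invoking the MaxPoolGrad
op.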