Add symbolic gradient functions for Conv2D and MaxPool

Change: 115608522
This commit is contained in:
A. Unique TensorFlower 2016-02-25 15:02:54 -08:00 committed by TensorFlower Gardener
parent 03fed366e4
commit aeae4825b3

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@ -94,4 +96,68 @@ Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad);
Status Conv2DGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FDH::Define(
// Arg defs
{"input: T", "filter: T", "grad: T"},
// Ret val defs
{"input_grad: T", "filter_grad: T"},
// Attr defs
{"T: {float, double}",
"strides: list(int)",
"use_cudnn_on_gpu: bool = true",
GetPaddingAttrString(),
GetConvnetDataFormatAttrString()},
// Nodes
{
{{"i_shape"}, "Shape", {"input"}, {{"T", "$T"}}},
{{"input_grad"}, "Conv2DBackpropInput", {"i_shape", "filter", "grad"},
/*Attrs=*/{{"T", "$T"},
{"strides", "$strides"},
{"padding", "$padding"},
{"data_format", "$data_format"},
{"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},
{{"f_shape"}, "Shape", {"filter"}, {{"T", "$T"}}},
{{"filter_grad"}, "Conv2DBackpropFilter", {"input", "f_shape", "grad"},
/*Attrs=*/{{"T", "$T"},
{"strides", "$strides"},
{"padding", "$padding"},
{"data_format", "$data_format"},
{"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},
});
// clang-format on
return Status::OK();
}
REGISTER_OP_GRADIENT("Conv2D", Conv2DGrad);
Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FDH::Define(
// Arg defs
{"input: float", "grad: float"},
// Ret val defs
{"output: float"},
// Attr defs
{"ksize: list(int) >= 4",
"strides: list(int) >= 4",
GetPaddingAttrString()},
// Nodes
{
// Invoke MaxPool again to recompute the outputs (removed by CSE?).
{{"maxpool"}, "MaxPool", {"input"},
/*Attrs=*/{{"ksize", "$ksize"},
{"strides", "$strides"},
{"padding", "$padding"}}},
{{"output"}, "MaxPoolGrad", {"input", "maxpool", "grad"},
/*Attrs=*/{{"ksize", "$ksize"},
{"strides", "$strides"},
{"padding", "$padding"}}}
});
// clang-format on
return Status::OK();
}
REGISTER_OP_GRADIENT("MaxPool", MaxPoolGrad);
} // end namespace tensorflow