Add symbolic gradient functions for Conv2D and MaxPool
Change: 115608522
This commit is contained in:
parent
03fed366e4
commit
aeae4825b3
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -94,4 +96,68 @@ Status CrossEntropyGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
}
|
||||
REGISTER_OP_GRADIENT("CrossEntropy", CrossEntropyGrad);
|
||||
|
||||
Status Conv2DGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// clang-format off
|
||||
*g = FDH::Define(
|
||||
// Arg defs
|
||||
{"input: T", "filter: T", "grad: T"},
|
||||
// Ret val defs
|
||||
{"input_grad: T", "filter_grad: T"},
|
||||
// Attr defs
|
||||
{"T: {float, double}",
|
||||
"strides: list(int)",
|
||||
"use_cudnn_on_gpu: bool = true",
|
||||
GetPaddingAttrString(),
|
||||
GetConvnetDataFormatAttrString()},
|
||||
// Nodes
|
||||
{
|
||||
{{"i_shape"}, "Shape", {"input"}, {{"T", "$T"}}},
|
||||
{{"input_grad"}, "Conv2DBackpropInput", {"i_shape", "filter", "grad"},
|
||||
/*Attrs=*/{{"T", "$T"},
|
||||
{"strides", "$strides"},
|
||||
{"padding", "$padding"},
|
||||
{"data_format", "$data_format"},
|
||||
{"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},
|
||||
|
||||
{{"f_shape"}, "Shape", {"filter"}, {{"T", "$T"}}},
|
||||
{{"filter_grad"}, "Conv2DBackpropFilter", {"input", "f_shape", "grad"},
|
||||
/*Attrs=*/{{"T", "$T"},
|
||||
{"strides", "$strides"},
|
||||
{"padding", "$padding"},
|
||||
{"data_format", "$data_format"},
|
||||
{"use_cudnn_on_gpu", "$use_cudnn_on_gpu"}}},
|
||||
});
|
||||
// clang-format on
|
||||
return Status::OK();
|
||||
}
|
||||
REGISTER_OP_GRADIENT("Conv2D", Conv2DGrad);
|
||||
|
||||
Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||
// clang-format off
|
||||
*g = FDH::Define(
|
||||
// Arg defs
|
||||
{"input: float", "grad: float"},
|
||||
// Ret val defs
|
||||
{"output: float"},
|
||||
// Attr defs
|
||||
{"ksize: list(int) >= 4",
|
||||
"strides: list(int) >= 4",
|
||||
GetPaddingAttrString()},
|
||||
// Nodes
|
||||
{
|
||||
// Invoke MaxPool again to recompute the outputs (removed by CSE?).
|
||||
{{"maxpool"}, "MaxPool", {"input"},
|
||||
/*Attrs=*/{{"ksize", "$ksize"},
|
||||
{"strides", "$strides"},
|
||||
{"padding", "$padding"}}},
|
||||
{{"output"}, "MaxPoolGrad", {"input", "maxpool", "grad"},
|
||||
/*Attrs=*/{{"ksize", "$ksize"},
|
||||
{"strides", "$strides"},
|
||||
{"padding", "$padding"}}}
|
||||
});
|
||||
// clang-format on
|
||||
return Status::OK();
|
||||
}
|
||||
REGISTER_OP_GRADIENT("MaxPool", MaxPoolGrad);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user