Merge pull request from Intel-tensorflow:mabuzain/native-fmt-max-pooling-graph

PiperOrigin-RevId: 336653431
Change-Id: I7ac3af7ba1b1d5f86963e265a54169bdbe2955eb
commit cb42c4dea1
Author: TensorFlower Gardener
Date:   2020-10-12 06:49:38 -07:00
6 changed files with 184 additions and 60 deletions


@@ -2458,7 +2458,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
for (auto ws : wsinfo_) {
if (orig_node->type_string() == ws.fwd_op &&
mkl_op_registry::IsMklLayoutDependentOp(
mkl_op_registry::IsMklOp(
mkl_op_registry::GetMklOpName(orig_node->type_string()), T)) {
// If this op is a fwd op, then we need to check if there is an
// edge from this node's fwd_slot to bwdop's bwd_slot. If there is
@@ -2485,7 +2485,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
nb->Attr("workspace_enabled", false);
}
} else if (orig_node->type_string() == ws.bwd_op &&
mkl_op_registry::IsMklLayoutDependentOp(
mkl_op_registry::IsMklOp(
mkl_op_registry::GetMklOpName(orig_node->type_string()),
T)) {
// If this op is a bwd op, then we need to add workspace edge and
@@ -2509,10 +2509,14 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
CHECK_NOTNULL(ws_tensors);
// Add workspace edge between fwd op and bwd op.
ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
// Add Mkl tensor edge for workspace edge between fwd op and bwd op.
ws_tensors->push_back(NodeBuilder::NodeOut(
e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
e->src()->num_outputs())));
// Check if we are running in native format mode. If so,
// we don't need to have an Mkl metadata tensor for the workspace.
if (!NativeFormatEnabled()) {
// Add Mkl tensor edge for workspace edge between fwd op and bwd op.
ws_tensors->push_back(NodeBuilder::NodeOut(
e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
e->src()->num_outputs())));
}
*are_ws_tensors_added = true;
// In terms of input ordering, we add these calls to add Input
// here because workspace edge (and its Mkl tensor) is the last
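Context for the new NativeFormatEnabled() branch above: in layout-dependent mode every data edge is paired with a uint8 Mkl metadata edge, and DataIndexToMetaDataIndex maps a data output slot to the slot of its metadata twin. A rough sketch of that mapping, assuming the contiguous tensor ordering (all data outputs first, metadata outputs after); the sketch's name and the example slot numbers are illustrative, not taken from this commit:

// Sketch only: data-slot -> metadata-slot mapping under contiguous ordering
// (assumed). With 2N outputs, data output i pairs with metadata output i + N.
inline int DataIndexToMetaDataIndexSketch(int data_slot, int total_outputs) {
  return data_slot + total_outputs / 2;
}

// e.g. a forward pooling op with outputs {output, workspace, mkl_output,
// mkl_workspace}: total_outputs == 4, so the workspace edge (slot 1) pairs
// with its metadata edge at slot 1 + 4 / 2 == 3. In native-format mode the
// metadata slots do not exist, hence the guard above skips the second
// push_back.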
@@ -3671,6 +3675,15 @@ Status MklLayoutRewritePass::RewriteNodeForJustOpNameChange(
return s;
}
std::vector<NodeBuilder::NodeOut> workspace_tensors;
bool are_workspace_tensors_available = false;
AddWorkSpaceEdgeIfNeeded(g, orig_node, &nb, &workspace_tensors,
&are_workspace_tensors_available);
if (are_workspace_tensors_available) {
DCHECK_EQ(workspace_tensors.size(), 1);
nb.Input(workspace_tensors[0].node, workspace_tensors[0].index);
}
if (!NativeFormatEnabled()) {
ri->copy_attrs(const_cast<const Node*>(orig_node), &nb, true);
} else {
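A note on the NativeFormatEnabled() guard used throughout this pass: at the time of this commit the switch was typically read once from an environment variable, assumed here to be TF_ENABLE_MKL_NATIVE_FORMAT with a false default. The sketch below shows the general shape of such a check, not the actual helper in mkl_util.h:

// Sketch (assumption): how the native-format switch is typically gated.
#include "tensorflow/core/util/env_var.h"

inline bool NativeFormatEnabledSketch() {
  static bool enabled = [] {
    bool value = false;  // default: classic layout-dependent Mkl format
    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_ENABLE_MKL_NATIVE_FORMAT",
                                               /*default_val=*/false, &value));
    return value;
  }();
  return enabled;
}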


@@ -1591,6 +1591,10 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/false);
}
Status MaxPoolGradShape(shape_inference::InferenceContext* c) {
return UnchangedShapeWithRank(c, 4);
}
Status MaxPoolShapeWithExplicitPadding(shape_inference::InferenceContext* c) {
return MaxPoolShapeImpl(c, /*supports_explicit_padding=*/true);
}
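Both new helpers pin the output to the first input's shape after a rank check, the behavior previously duplicated as inline lambdas in nn_ops.cc (see the last file in this commit). A minimal sketch of what UnchangedShapeWithRank(c, 4) amounts to for MaxPoolGrad; the function name and local names are illustrative:

// Sketch: MaxPoolGradShape validates that input 0 (orig_input) is rank 4 and
// forwards its shape unchanged to output 0.
Status MaxPoolGradShapeSketch(shape_inference::InferenceContext* c) {
  shape_inference::ShapeHandle out;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &out));  // rank check
  c->set_output(0, out);  // gradient has the same shape as orig_input
  return Status::OK();
}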
@@ -1779,6 +1783,10 @@ Status Pool3DShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
Status MaxPool3DGradShape(shape_inference::InferenceContext* c) {
return UnchangedShapeWithRank(c, 5);
}
Status AvgPool3DGradShape(shape_inference::InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));


@@ -181,9 +181,15 @@ Status MaxPoolShape(shape_inference::InferenceContext* c);
// Shape function for MaxPoolV2-like operations.
Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs);
// Shape function for MaxPoolGrad-like operations.
Status MaxPoolGradShape(shape_inference::InferenceContext* c);
// Shape function for 3D Pooling operations.
Status Pool3DShape(shape_inference::InferenceContext* c);
// Shape function for MaxPool3DGrad-like operations.
Status MaxPool3DGradShape(shape_inference::InferenceContext* c);
// Shape function for AvgPool3DGrad-like operations.
Status AvgPool3DGradShape(shape_inference::InferenceContext* c);


@@ -44,7 +44,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
// An implementation of MaxPooling (forward).
template <typename Device, typename T>
template <typename Device, typename T, bool native_format = false>
class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
public:
explicit MklMaxPoolingOp(OpKernelConstruction* context)
@@ -52,6 +52,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
// In Max Pooling, MKL-DNN does not allow passing workspace as nullptr.
// So we set workspace_enabled_ to true.
this->workspace_enabled_ = true;
this->native_format_ = native_format;
}
void Compute(OpKernelContext* context) override {
@@ -59,7 +60,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
const Tensor& input_tensor =
MklGetInput(context, this->kInputTensorIndexInput);
MklDnnShape dnn_shape_input;
GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input,
this->native_format_);
this->SanityCheckInput(context, input_tensor, dnn_shape_input);
if (!context->status().ok()) return;
@@ -230,7 +232,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
workspace_tf_shape.AddDim(workspace_bytes);
AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace,
&workspace_tensor, workspace_tf_shape,
workspace_mkl_shape);
workspace_mkl_shape, this->native_format_);
DCHECK(workspace_tensor);
dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
}
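The extra this->native_format_ argument threaded into GetMklShape and AllocateOutputSetMklShape asks those helpers to bypass the Mkl metadata tensors entirely. A rough sketch of the intended behavior, assumed rather than copied from mkl_util.h:

// Sketch (assumption): what the native_format flag changes in GetMklShape.
void GetMklShapeSketch(OpKernelContext* context, int input_idx,
                       MklDnnShape* mkl_shape, bool native_format) {
  if (native_format) {
    // Native-format mode: there is no companion metadata input, so the
    // tensor is reported as a plain TensorFlow tensor.
    mkl_shape->SetMklTensor(false);
    return;
  }
  // Layout-dependent mode: deserialize MklDnnShape from the uint8 metadata
  // tensor that accompanies data input `input_idx`.
}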
@@ -242,11 +244,13 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
// - The original output tensor
// - Backprop tensor for output
// It produces one output: backprop tensor for input.
template <class Device, class T>
template <class Device, class T, bool native_format = false>
class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
public:
explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
: MklPoolingBackwardOpBase<T>(context) {}
: MklPoolingBackwardOpBase<T>(context) {
this->native_format_ = native_format;
}
void Compute(OpKernelContext* context) override {
try {
const Tensor& orig_input_tensor =
@@ -256,8 +260,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
const Tensor& workspace_tensor =
MklGetInput(context, kInputTensorIndexWorkspace);
MklDnnShape orig_input_mkl_shape, grad_mkl_shape;
GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape);
GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape);
GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape,
this->native_format_);
GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape,
this->native_format_);
if (!context->status().ok()) return;
MklDnnData<T> grad_dnn_data(&cpu_engine_);
@@ -337,7 +343,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
std::shared_ptr<PoolingBwdPd> pooling_bwd_pd =
pooling_bwd->GetPoolingBwdPd();
T* diff_dst_data = nullptr;
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
if (!this->native_format_ &&
IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
pooling_bwd)) {
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
grad_dnn_data.CheckReorderToOpMem(
@@ -391,36 +398,56 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
engine cpu_engine_ = engine(ENGINE_CPU, 0);
}; // MklMaxPoolingGradOp
#define REGISTER_MKL_MAXPOOL3D_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklMaxPool3D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklMaxPoolingOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklMaxPool3DGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklMaxPoolingGradOp<CPUDevice, T>);
#define REGISTER_MKL_MAXPOOL3D_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklMaxPool3D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklMaxPoolingOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklMaxPool3DGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklMaxPoolingGradOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("_MklNativeMaxPool3D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklMaxPoolingOp<CPUDevice, T, true>); \
REGISTER_KERNEL_BUILDER(Name("_MklNativeMaxPool3DGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklMaxPoolingGradOp<CPUDevice, T, true>);
TF_CALL_float(REGISTER_MKL_MAXPOOL3D_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_MAXPOOL3D_KERNELS);
#define REGISTER_MKL_MAXPOOL_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklMaxPool") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklMaxPoolingOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklMaxPoolGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklMaxPoolingGradOp<CPUDevice, T>);
#define REGISTER_MKL_MAXPOOL_KERNELS(T) \
REGISTER_KERNEL_BUILDER( \
Name("_MklMaxPool") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklMaxPoolingOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("_MklMaxPoolGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklMaxPoolingGradOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("_MklNativeMaxPool") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklMaxPoolingOp<CPUDevice, T, true>); \
REGISTER_KERNEL_BUILDER(Name("_MklNativeMaxPoolGrad") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
MklMaxPoolingGradOp<CPUDevice, T, true>);
TF_CALL_float(REGISTER_MKL_MAXPOOL_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_MAXPOOL_KERNELS);
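The registrations above use the bool template parameter as a compile-time switch: one class template, two kernel labels, no runtime dispatch. A stripped-down sketch of the pattern with hypothetical names, outside of TensorFlow:

// Sketch: a bool template parameter selects the native-format code path at
// compile time; each specialization is registered under a different label.
template <typename T, bool native_format = false>
class PoolKernelSketch {
 public:
  PoolKernelSketch() : native_format_(native_format) {}
  void Compute() {
    if (native_format_) {
      // Plain TF tensors in and out; no Mkl metadata handling.
    } else {
      // Layout-dependent path: read/write the companion metadata tensors.
    }
  }

 private:
  bool native_format_;
};

// PoolKernelSketch<float>        -> "_MklMaxPool"        (layout dependent)
// PoolKernelSketch<float, true>  -> "_MklNativeMaxPool"  (native format)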


@@ -508,6 +508,86 @@ propagation. Uses oneDNN APIs to compute gradients of AvgPool3D function.
expected to invoke these operators.
)doc");
REGISTER_OP("_MklNativeMaxPool")
.Attr("T: {float, half, bfloat16} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr(GetExplicitPaddingsAttrString())
.Attr("workspace_enabled: bool = false")
.Input("input: T")
.Output("output: T")
.Output("workspace: uint8")
.SetShapeFn(shape_inference::MaxPoolShape)
.Doc(R"doc(
oneDNN version of MaxPool operator that does not depend
on layout propagation. Uses oneDNN APIs to perform max pooling
on the input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
REGISTER_OP("_MklNativeMaxPoolGrad")
.Attr("T: {float, half, bfloat16} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr("workspace_enabled: bool = false")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr(GetExplicitPaddingsAttrString())
.Input("orig_input: T")
.Input("orig_output: T")
.Input("grad: T")
.Input("workspace: uint8")
.Output("output: T")
.SetShapeFn(shape_inference::MaxPoolGradShape)
.Doc(R"doc(
oneDNN version of MaxPoolGrad that does not depend on layout propagation.
Uses oneDNN APIs to compute gradients of MaxPool operator.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
REGISTER_OP("_MklNativeMaxPool3D")
.Input("input: T")
.Output("output: T")
.Output("workspace: uint8")
.Attr("ksize: list(int) >= 5")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
.Attr("T: {half, bfloat16, float}")
.Attr("workspace_enabled: bool = false")
.SetShapeFn(shape_inference::Pool3DShape)
.Doc(R"doc(
oneDNN version of MaxPool3D operator that does not depend on layout propagation.
Uses oneDNN APIs to perform 3D max pooling on the input.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
REGISTER_OP("_MklNativeMaxPool3DGrad")
.Input("orig_input: TInput")
.Input("orig_output: TInput")
.Input("grad: T")
.Input("workspace: uint8")
.Output("output: T")
.Attr("ksize: list(int) >= 5")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
.Attr("T: {half, bfloat16, float} = DT_FLOAT")
.Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
.Attr("workspace_enabled: bool = false")
.SetShapeFn(shape_inference::MaxPool3DGradShape)
.Doc(R"doc(
oneDNN version of MaxPool3DGrad operator that does not depend on layout
propagation. Uses oneDNN APIs to compute gradients of MaxPool3D function.
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
REGISTER_OP("_MklQuantizedMaxPool")
.Input("input: T")
.Input("min_input: float")


@@ -761,9 +761,7 @@ REGISTER_OP("MaxPool3DGrad")
.Attr(GetConvnet3dDataFormatAttrString())
.Attr("T: {half, bfloat16, float} = DT_FLOAT")
.Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 5);
});
.SetShapeFn(shape_inference::MaxPool3DGradShape);
REGISTER_OP("MaxPool3DGradGrad")
.Input("orig_input: T")
@@ -867,9 +865,7 @@ REGISTER_OP("MaxPoolGrad")
.Input("grad: T")
.Output("output: T")
.Attr("T: realnumbertype = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 4);
});
.SetShapeFn(shape_inference::MaxPoolGradShape);
REGISTER_OP("MaxPoolGradV2")
.Attr(GetPaddingAttrString())
@@ -881,9 +877,7 @@ REGISTER_OP("MaxPoolGradV2")
.Input("strides: int32")
.Output("output: T")
.Attr("T: realnumbertype = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 4);
});
.SetShapeFn(shape_inference::MaxPoolGradShape);
// TODO(b/150813181): Implement explicit padding.
REGISTER_OP("MaxPoolGradGrad")
@@ -2331,14 +2325,12 @@ REGISTER_OP("_MklMaxPoolGrad")
.Input("mkl_workspace: uint8")
.Output("output: T")
.Output("mkl_output: uint8")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 4);
})
.SetShapeFn(shape_inference::MaxPoolGradShape)
.Doc(R"doc(
MKL version of MaxPoolGrad. Uses MKL DNN APIs to compute gradients of
oneDNN version of MaxPoolGrad. Uses oneDNN APIs to compute gradients of
MaxPool operator.
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
@@ -2462,14 +2454,12 @@ REGISTER_OP("_MklMaxPool3DGrad")
.Attr("T: {half, bfloat16, float} = DT_FLOAT")
.Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
.Attr("workspace_enabled: bool = false")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 5);
})
.SetShapeFn(shape_inference::MaxPool3DGradShape)
.Doc(R"doc(
MKL version of MklPool3DGrad operator. Uses MKL DNN APIs to compute gradients
of MklPool function.
oneDNN version of MaxPool3DGrad operator. Uses oneDNN APIs to compute gradients
of MaxPool3D function.
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");