Merge pull request #30401 from Intel-tensorflow:eager-conv-fwd
PiperOrigin-RevId: 261397141
This commit is contained in:
commit
660f39fd0d
tensorflow/core
common_runtime/eager
graph
kernels
ops
util
@ -3,6 +3,11 @@ load(
|
||||
"tf_cc_test",
|
||||
"tf_cuda_library",
|
||||
)
|
||||
load(
|
||||
"//third_party/mkl:build_defs.bzl",
|
||||
"if_mkl",
|
||||
"mkl_deps",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -266,7 +271,14 @@ cc_library(
|
||||
"//tensorflow/core/distributed_runtime/eager:remote_execute_node",
|
||||
"//tensorflow/core/distributed_runtime/eager:remote_copy_node",
|
||||
],
|
||||
}),
|
||||
}) + if_mkl([":mkl_eager_op_rewrite"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mkl_eager_op_rewrite",
|
||||
srcs = ["mkl_eager_op_rewrite.cc"],
|
||||
copts = ["-DINTEL_MKL=1"],
|
||||
deps = [":eager_op_rewrite_registry"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
180
tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
Normal file
180
tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc
Normal file
@ -0,0 +1,180 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifdef INTEL_MKL
|
||||
#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
|
||||
#include "tensorflow/core/graph/mkl_graph_util.h"
|
||||
#include "tensorflow/core/graph/mkl_layout_pass.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class MklEagerOpRewrite : public EagerOpRewrite {
|
||||
public:
|
||||
MklEagerOpRewrite(string name, string file, string line);
|
||||
typedef struct {
|
||||
string op_name;
|
||||
std::function<bool(EagerOperation*)> RewriteRule;
|
||||
std::function<Status(EagerOperation*, std::unique_ptr<EagerOperation>*)>
|
||||
CreateMklOp;
|
||||
} MklEagerOp;
|
||||
|
||||
private:
|
||||
// TODO(intel-tf): refactor with unordered_map;
|
||||
// especially when adding more ops/rewrite rules in future.
|
||||
std::vector<MklEagerOp> mkl_eager_ops_;
|
||||
|
||||
// The entry point to execute the op rewrite.
|
||||
Status Run(EagerOperation* orig_op,
|
||||
std::unique_ptr<tensorflow::EagerOperation>* out_op);
|
||||
|
||||
// Initializes the new op and sets up its inputs and attributes
|
||||
static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
|
||||
std::unique_ptr<EagerOperation>* new_mkl_op);
|
||||
|
||||
// Creates new MKL op for Conv2D, Conv2DBackpropInput and
|
||||
// Conv2DBackpropFilter.
|
||||
static Status CreateMklConv2DOp(
|
||||
EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op);
|
||||
|
||||
// Rewrite rule for Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter.
|
||||
static bool RewriteConv2D(EagerOperation* op);
|
||||
|
||||
// Calls op-specific rewrite function to create new MKL op.
|
||||
Status RewriteToMklOp(EagerOperation* orig_op,
|
||||
std::unique_ptr<EagerOperation>* mkl_op,
|
||||
const int op_idx);
|
||||
|
||||
// Checks whether we can rewrite the op to MKL one or not.
|
||||
bool ShouldRewriteOp(EagerOperation* op, int* op_idx);
|
||||
};
|
||||
|
||||
// Registers MklEagerOpRewrite with the eager op rewrite registry under the
// PRE_EXECUTION phase.
// NOTE(review): presumably this makes the rewrite run before an eager op is
// executed -- confirm against the REGISTER_REWRITE definition.
REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
|
||||
|
||||
// Constructor
|
||||
MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
|
||||
: EagerOpRewrite(name, file, line) {
|
||||
mkl_eager_ops_.push_back({"Conv2D", RewriteConv2D, CreateMklConv2DOp});
|
||||
mkl_eager_ops_.push_back(
|
||||
{"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
|
||||
mkl_eager_ops_.push_back(
|
||||
{"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
|
||||
}
|
||||
|
||||
// Entry point of the rewrite.
//
// If `orig_op` qualifies for an MKL rewrite, the replacement op is built into
// `out_op`; otherwise `out_op` is not touched and OK is returned.
//
// A failure while building the replacement is returned to the caller as a
// Status rather than aborting the process.
Status MklEagerOpRewrite::Run(
    EagerOperation* orig_op,
    std::unique_ptr<tensorflow::EagerOperation>* out_op) {
  int found_op_idx = -1;
  if (!ShouldRewriteOp(orig_op, &found_op_idx)) {
    // Nothing to rewrite for this op.
    return Status::OK();
  }
  return RewriteToMklOp(orig_op, out_op, found_op_idx);
}
|
||||
|
||||
// Creates the op `mkl_op_name` in `new_mkl_op` as a replacement for `orig_op`.
//
// The new op is created in orig_op's EagerContext, receives all of orig_op's
// inputs and attributes, gets its "_kernel" attribute set to the MKL
// name-change label, and is given orig_op's device (or, when orig_op has no
// device, its device name).
//
// Returns the error reported by AttrTypeMapForOp for `mkl_op_name`, if any;
// in that case `new_mkl_op` is not modified.
Status MklEagerOpRewrite::SetupNewOp(
    EagerOperation* orig_op, const string mkl_op_name,
    std::unique_ptr<EagerOperation>* new_mkl_op) {
  const tensorflow::AttrTypeMap* types;
  bool is_function = false;
  TF_RETURN_IF_ERROR(
      tensorflow::AttrTypeMapForOp(mkl_op_name.c_str(), &types, &is_function));
  EagerContext* ctx = orig_op->EagerContext();
  new_mkl_op->reset(new tensorflow::EagerOperation(ctx, mkl_op_name.c_str(),
                                                   is_function, types));

  // Add all inputs to the new op.
  for (const auto& input : orig_op->Inputs()) {
    (*new_mkl_op)->AddInput(input);
  }

  // Copy all attributes to the new op.
  const NodeDef& orig_ndef = orig_op->MutableAttrs()->BuildNodeDef();
  AttrSlice attr_list(orig_ndef);
  for (const auto& attr : attr_list) {
    (*new_mkl_op)->MutableAttrs()->Set(attr.first, attr.second);
  }

  // Tag the new op with the MKL name-change kernel label.
  (*new_mkl_op)
      ->MutableAttrs()
      ->Set("_kernel", mkl_op_registry::kMklNameChangeOpLabel);

  // Keep the placement of the original op: prefer the device itself and fall
  // back to the device name when no device has been set.
  if (orig_op->Device() != nullptr) {
    (*new_mkl_op)->SetDevice(orig_op->Device());
  } else {
    string device_name =
        DeviceNameUtils::ParsedNameToString(orig_op->GetDeviceName());
    (*new_mkl_op)->SetDeviceName(device_name.c_str());
  }
  return Status::OK();
}
|
||||
|
||||
// Builds the MKL eager replacement for a convolution op.
//
// The replacement op name is derived from orig_op's name with
// GetMklEagerOpName, and the op itself is built by SetupNewOp. Any error from
// SetupNewOp is returned to the caller instead of aborting the process.
Status MklEagerOpRewrite::CreateMklConv2DOp(
    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
  const string mkl_op_name =
      mkl_op_registry::GetMklEagerOpName(orig_op->Name());
  return SetupNewOp(orig_op, mkl_op_name, mkl_conv2d_op);
}
|
||||
|
||||
// Returns true when `op` should be rewritten to an MKL op, storing the index
// of the matching rewrite-table entry in `op_idx`. Returns false when MKL is
// disabled at runtime, when the op has no readable "T" attribute, when no MKL
// name-change kernel is registered for the op, or when no table entry both
// matches the op name and accepts the op.
bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
  // Don't rewrite the op if MKL use is disabled at runtime.
  if (DisableMKL()) return false;

  DataType data_type;
  if (!op->Attrs().Get("T", &data_type).ok()) return false;

  // Check if we have registered MKL kernel for this op, under either the
  // eager-specific name or the regular MKL name.
  const bool has_mkl_kernel =
      mkl_op_registry::IsMklNameChangeOp(
          mkl_op_registry::GetMklEagerOpName(op->Name()), data_type) ||
      mkl_op_registry::IsMklNameChangeOp(
          mkl_op_registry::GetMklOpName(op->Name()), data_type);
  if (!has_mkl_kernel) return false;

  *op_idx = -1;
  // Walk the table and ask the first entry with a matching name whose rewrite
  // rule accepts the op.
  for (size_t i = 0; i < mkl_eager_ops_.size(); ++i) {
    const MklEagerOp& entry = mkl_eager_ops_[i];
    if (entry.op_name == op->Name() && entry.RewriteRule(op)) {
      *op_idx = static_cast<int>(i);
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Builds the MKL replacement for `orig_op` into `mkl_op` by calling the
// factory of rewrite-table entry `op_idx`.
//
// `op_idx` must be an index previously produced by ShouldRewriteOp.
// Returns the factory's Status so that a failed rewrite is visible to the
// caller (the result used to be discarded and OK returned unconditionally).
Status MklEagerOpRewrite::RewriteToMklOp(
    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op,
    const int op_idx) {
  return mkl_eager_ops_[op_idx].CreateMklOp(orig_op, mkl_op);
}
|
||||
|
||||
// Rewrite rule for the convolution ops: returns true unless the op uses
// explicit padding.
//
// If the "padding" attribute cannot be read from the op, the rewrite is
// declined (returns false) rather than aborting the process, so the original
// op is left in place.
bool MklEagerOpRewrite::RewriteConv2D(EagerOperation* op) {
  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
  string padding;
  if (!GetNodeAttr(ndef, "padding", &padding).ok()) {
    return false;
  }
  // Right now MKL Conv2D does not support explicit padding.
  return (padding != "EXPLICIT");
}
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // INTEL_MKL
|
@ -104,12 +104,24 @@ static const char* kMklQuantizedOpLabelPattern = "label='QuantizedMklOp'";
|
||||
|
||||
// Prefix that we add to Tensorflow op name to construct Mkl op name.
|
||||
static const char* const kMklOpPrefix = "_Mkl";
|
||||
// TODO(intel-tf): PR review feedback (penpornk)
|
||||
// Can we add eager_mode (or is_eager) as an op attribute instead?
|
||||
// This way we don't need to rename the op just to pass eager_mode
|
||||
// through template parameter.
|
||||
static const char* const kMklEagerOpPrefix = "_MklEager";
|
||||
|
||||
// Get the name of Mkl op from original TensorFlow op
|
||||
// We prefix 'Mkl' to the original op to get Mkl op.
|
||||
inline string GetMklOpName(const string& name) {
|
||||
return string(kMklOpPrefix) + name;
|
||||
}
|
||||
|
||||
// Get the name of Mkl Eager op from original TensorFlow op
|
||||
// We prefix 'MklEager' to the original op to get Mkl Eager op.
|
||||
inline string GetMklEagerOpName(const string& name) {
|
||||
return string(kMklEagerOpPrefix) + name;
|
||||
}
|
||||
|
||||
// Check whether opname with type T is registered as MKL operator
|
||||
// that can accept input tensors in MKL layout.
|
||||
//
|
||||
|
@ -518,7 +518,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
// Base class for convolution forward operations
|
||||
template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
|
||||
typename Toutput, typename Ttemp_output, typename Tpadding,
|
||||
bool bias_enabled, bool pad_enabled, bool is_depthwise>
|
||||
bool bias_enabled, bool pad_enabled, bool is_depthwise,
|
||||
bool eager_mode>
|
||||
class MklConvOp : public OpKernel {
|
||||
public:
|
||||
~MklConvOp() {}
|
||||
@ -545,8 +546,10 @@ class MklConvOp : public OpKernel {
|
||||
"strides in the batch and depth dimensions."));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||
is_filter_const_ = false;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("is_filter_const", &is_filter_const_));
|
||||
if (context->HasAttr("is_filter_const")) {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("is_filter_const", &is_filter_const_));
|
||||
}
|
||||
|
||||
if (strides_.size() == 4) {
|
||||
OP_REQUIRES(context, dilations_.size() == 4,
|
||||
@ -589,8 +592,9 @@ class MklConvOp : public OpKernel {
|
||||
const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
|
||||
|
||||
MklDnnShape src_mkl_shape, filter_mkl_shape;
|
||||
GetMklShape(context, kInputIndex_Src, &src_mkl_shape);
|
||||
GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape);
|
||||
GetMklShape(context, kInputIndex_Src, &src_mkl_shape, eager_mode);
|
||||
GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape, eager_mode);
|
||||
|
||||
OP_REQUIRES(context, filter_mkl_shape.IsMklTensor() == false,
|
||||
errors::InvalidArgument("Filter should not be in "
|
||||
"Mkl Layout"));
|
||||
@ -620,8 +624,9 @@ class MklConvOp : public OpKernel {
|
||||
// Get shapes of input tensors in MKL-DNN order
|
||||
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
|
||||
dilations_);
|
||||
auto src_tf_shape = GetTfShape(context, kInputIndex_Src);
|
||||
auto filter_tf_shape = GetTfShape(context, kInputIndex_Filter);
|
||||
auto src_tf_shape = GetTfShape(context, kInputIndex_Src, eager_mode);
|
||||
auto filter_tf_shape =
|
||||
GetTfShape(context, kInputIndex_Filter, eager_mode);
|
||||
conv_utl.GetConvFwdSizesInMklOrder(
|
||||
src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
|
||||
&dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
|
||||
@ -634,15 +639,17 @@ class MklConvOp : public OpKernel {
|
||||
|
||||
// Corner cases: output with 0 elements and 0 batch size.
|
||||
Tensor* dst_tensor = nullptr;
|
||||
Tensor tmp_tensor;
|
||||
bool emit_filter_output = (typeid(Tinput) == typeid(Tfilter) &&
|
||||
typeid(Tinput) == typeid(Toutput) &&
|
||||
(typeid(Tinput) == typeid(float) ||
|
||||
typeid(Tinput) == typeid(bfloat16)));
|
||||
typeid(Tinput) == typeid(bfloat16))) &&
|
||||
!eager_mode;
|
||||
if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
|
||||
MklDnnShape dst_mkl_shape;
|
||||
dst_mkl_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
|
||||
dst_tf_shape, dst_mkl_shape);
|
||||
src_tf_shape, dst_mkl_shape, eager_mode);
|
||||
|
||||
// MklConv2D/3D also outputs converted filter as 2nd output.
|
||||
filter_mkl_shape.SetMklTensor(false);
|
||||
@ -754,9 +761,10 @@ class MklConvOp : public OpKernel {
|
||||
convFwdDims, do_not_cache);
|
||||
|
||||
// Allocate output tensors `output_tensor` and `filter_out_tensor`
|
||||
MklDnnShape output_mkl_shape;
|
||||
std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
|
||||
AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
|
||||
&dst_tensor);
|
||||
&output_mkl_shape, &dst_tensor, &tmp_tensor);
|
||||
|
||||
Tensor* filter_out_tensor = nullptr;
|
||||
if (emit_filter_output) {
|
||||
@ -828,7 +836,28 @@ class MklConvOp : public OpKernel {
|
||||
this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
|
||||
conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
|
||||
} else {
|
||||
conv_fwd->Execute(src_data, filter_data, dst_data);
|
||||
if (!eager_mode) {
  conv_fwd->Execute(src_data, filter_data, dst_data);
} else {
  // In eager mode we first write the output to temporary
  // buffer in MKL format. Then we convert the data to TF format.
  Ttemp_output* tmp_data = reinterpret_cast<Ttemp_output*>(
      tmp_tensor.flat<Toutput>().data());
  conv_fwd->Execute(src_data, filter_data, tmp_data);

  // Now we need to convert the output to TF format.
  auto output_tf_md = output_mkl_shape.GetTfLayout();
  auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
  auto dst_pd = (*conv_fwd_pd).dst_primitive_desc();
  mkldnn::reorder::primitive_desc reorder_pd =
      mkldnn::reorder::primitive_desc(dst_pd, output_tf_pd);

  // The memory wrappers are only needed until the reorder below has
  // finished, so they live on the stack; allocating them with `new` and
  // never deleting them leaked two objects per call.
  memory tmp_data_mem(dst_pd, tmp_data);
  memory dst_data_mem(output_tf_pd, dst_data);
  std::vector<mkldnn::primitive> net;
  net.push_back(mkldnn::reorder(reorder_pd, tmp_data_mem, dst_data_mem));
  stream(stream::kind::eager).submit(net).wait();
}
|
||||
}
|
||||
|
||||
// Delete primitive since it is not cached.
|
||||
@ -942,7 +971,9 @@ class MklConvOp : public OpKernel {
|
||||
const ConvFwdPd& conv_prim_desc,
|
||||
const memory::dims& output_dims_mkl_order,
|
||||
MKL_TENSOR_FORMAT output_tf_format,
|
||||
Tensor** output_tensor) {
|
||||
MklDnnShape* output_mkl_shape,
|
||||
Tensor** output_tensor,
|
||||
Tensor* tmp_tensor) {
|
||||
DCHECK(output_tensor);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto dst_md = conv_prim_desc.dst_desc();
|
||||
@ -960,25 +991,29 @@ class MklConvOp : public OpKernel {
|
||||
}
|
||||
|
||||
// Allocate shape of MKL tensor
|
||||
MklDnnShape output_mkl_shape;
|
||||
output_mkl_shape.SetMklTensor(true);
|
||||
output_mkl_shape.SetMklLayout(&DST_MD);
|
||||
output_mkl_shape.SetElemType(MklDnnType<Toutput>());
|
||||
output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
|
||||
output_dims_mkl_order, output_tf_format);
|
||||
output_mkl_shape->SetMklTensor(true);
|
||||
output_mkl_shape->SetMklLayout(&DST_MD);
|
||||
output_mkl_shape->SetElemType(MklDnnType<Toutput>());
|
||||
output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
|
||||
output_dims_mkl_order, output_tf_format);
|
||||
|
||||
// Allocate shape of TF tensor
|
||||
TensorShape output_tf_shape;
|
||||
output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));
|
||||
if (eager_mode) {
|
||||
AllocTmpBuffer<Toutput>(context, tmp_tensor, output_tf_shape);
|
||||
output_tf_shape = output_mkl_shape->GetTfShape();
|
||||
}
|
||||
|
||||
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
|
||||
output_tf_shape, output_mkl_shape);
|
||||
output_tf_shape, *output_mkl_shape, eager_mode);
|
||||
if (fuse_add_) {
|
||||
const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add);
|
||||
MklDnnShape add_mkl_shape;
|
||||
GetMklShape(context, kInputIndex_Add, &add_mkl_shape);
|
||||
|
||||
// Check if reorder is needed
|
||||
if (add_mkl_shape == output_mkl_shape) {
|
||||
if (add_mkl_shape == *output_mkl_shape) {
|
||||
DCHECK((*output_tensor)->CopyFrom(add_tensor, output_tf_shape));
|
||||
} else {
|
||||
if (add_mkl_shape.IsMklTensor()) {
|
||||
@ -986,14 +1021,14 @@ class MklConvOp : public OpKernel {
|
||||
} else {
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
|
||||
output_mkl_shape.GetTfDataFormat());
|
||||
output_mkl_shape->GetTfDataFormat());
|
||||
DCHECK_NE(output_format_tag, memory::format_tag::undef);
|
||||
auto add_md = memory::desc(output_dims_mkl_order,
|
||||
MklDnnType<Toutput>(), output_format_tag);
|
||||
#else
|
||||
auto add_md =
|
||||
memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
|
||||
output_mkl_shape.GetTfDataFormat());
|
||||
output_mkl_shape->GetTfDataFormat());
|
||||
auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
void* add_buf = static_cast<void*>(
|
||||
@ -1290,11 +1325,11 @@ template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
|
||||
bool pad_enabled>
|
||||
class MklFusedConvOp
|
||||
: public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
|
||||
Tpadding, false, false, false> {
|
||||
Tpadding, false, false, false, false> {
|
||||
public:
|
||||
explicit MklFusedConvOp(OpKernelConstruction* context)
|
||||
: MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
|
||||
Tpadding, false, false, false>(context) {
|
||||
Tpadding, false, false, false, false>(context) {
|
||||
// Since we came here through the registration of _MklFusedConv2D, get
|
||||
// all information from 'fused_ops' and 'num_args'
|
||||
std::vector<string> fused_ops;
|
||||
@ -1386,7 +1421,7 @@ template <typename Device, typename Tbias, typename Toutput,
|
||||
typename Ttemp_output, bool bias_enabled, bool is_depthwise>
|
||||
class MklQuantizedConv2DOp
|
||||
: public MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output,
|
||||
int32, bias_enabled, false, is_depthwise> {
|
||||
int32, bias_enabled, false, is_depthwise, false> {
|
||||
public:
|
||||
virtual ~MklQuantizedConv2DOp() {
|
||||
if (this->input_bias_ != nullptr) {
|
||||
@ -1402,7 +1437,7 @@ class MklQuantizedConv2DOp
|
||||
|
||||
explicit MklQuantizedConv2DOp(OpKernelConstruction* context)
|
||||
: MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
|
||||
bias_enabled, false, is_depthwise>(context) {
|
||||
bias_enabled, false, is_depthwise, false>(context) {
|
||||
bool is_filter_const;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("is_filter_const", &is_filter_const));
|
||||
@ -1413,7 +1448,7 @@ class MklQuantizedConv2DOp
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Compute int32 output tensor
|
||||
MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
|
||||
bias_enabled, false, is_depthwise>::Compute(context);
|
||||
bias_enabled, false, is_depthwise, false>::Compute(context);
|
||||
|
||||
// Compute additional outputs: min/max scalars.
|
||||
int bias_index_offset;
|
||||
@ -1475,8 +1510,8 @@ class MklQuantizedConv2DOp
|
||||
void ExtendConvFwdParams(OpKernelContext* context,
|
||||
MklConvFwdParams& params) override {
|
||||
MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
|
||||
bias_enabled, false, is_depthwise>::ExtendConvFwdParams(context,
|
||||
params);
|
||||
bias_enabled, false, is_depthwise,
|
||||
false>::ExtendConvFwdParams(context, params);
|
||||
|
||||
// When the output type is quint8, the output data id requantized
|
||||
// into quint8. A post_op "output_scale" is added to do the conversion.
|
||||
@ -1674,7 +1709,9 @@ class MklQuantizedConv2DSumReluOp
|
||||
const ConvFwdPd& conv_prim_desc,
|
||||
const memory::dims& output_dims_mkl_order,
|
||||
MKL_TENSOR_FORMAT output_tf_format,
|
||||
Tensor** output_tensor) override {
|
||||
MklDnnShape* output_mkl_shape,
|
||||
Tensor** output_tensor,
|
||||
Tensor* tmp_tensor) override {
|
||||
int summand_idx = context->num_inputs() / 2 - 1;
|
||||
if (std::is_same<Toutput, quint8>::value) {
|
||||
summand_idx -= 2;
|
||||
@ -1701,12 +1738,12 @@ class MklQuantizedConv2DSumReluOp
|
||||
*output_tensor = const_cast<Tensor*>(&summand);
|
||||
return;
|
||||
}
|
||||
|
||||
MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
|
||||
bias_enabled, false,
|
||||
bias_enabled, false, false,
|
||||
false>::AllocateOutputTensor(context, conv_prim_desc,
|
||||
output_dims_mkl_order,
|
||||
output_tf_format, output_tensor);
|
||||
output_tf_format, output_mkl_shape,
|
||||
output_tensor, tmp_tensor);
|
||||
const Tensor& summand = MklGetInput(context, summand_idx);
|
||||
if (summand.dtype() != DT_FLOAT)
|
||||
TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,
|
||||
@ -2114,46 +2151,52 @@ REGISTER_KERNEL_BUILDER(
|
||||
MklQuantizedConv2DReluOp<CPUDevice, qint32, quint8, quint8, true, true>);
|
||||
|
||||
// Register 2D operations
|
||||
#define REGISTER_MKL_CPU_2D(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklConv2DWithBias") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("__MklDummyConv2DWithBias") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklDummyOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklPadWithConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tpaddings") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklPadWithConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int64>("Tpaddings") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("__MklDummyPadWithConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tpaddings") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklDummyOp<CPUDevice, T>);
|
||||
#define REGISTER_MKL_CPU_2D(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklConv2DWithBias") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, false>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("__MklDummyConv2DWithBias") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklDummyOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklPadWithConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tpaddings") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, false>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklPadWithConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int64>("Tpaddings") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, false>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("__MklDummyPadWithConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tpaddings") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklDummyOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_MklEagerConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklNameChangeOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>);
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_CPU_2D);
|
||||
TF_CALL_bfloat16(REGISTER_MKL_CPU_2D);
|
||||
@ -2164,7 +2207,7 @@ TF_CALL_bfloat16(REGISTER_MKL_CPU_2D);
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true>);
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>);
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE);
|
||||
TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE);
|
||||
@ -2210,7 +2253,7 @@ TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED);
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false>);
|
||||
MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>);
|
||||
TF_CALL_float(REGISTER_MKL_CPU_3D);
|
||||
TF_CALL_bfloat16(REGISTER_MKL_CPU_3D);
|
||||
|
||||
|
@ -1652,6 +1652,25 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_MklEagerConv2D")
|
||||
.Input("input: T")
|
||||
.Input("filter: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: {bfloat16, float}")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr("use_cudnn_on_gpu: bool = true")
|
||||
.Attr(GetPaddingAttrStringWithExplicit())
|
||||
.Attr(GetExplicitPaddingsAttrString())
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
|
||||
.Doc(R"doc(
|
||||
MKL version of Conv2D operator for Eager mode. Uses MKL DNN APIs to perform 2D convolution.
|
||||
|
||||
NOTE Do not invoke this operator directly in Python. Eager Op rewrite is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("__MklDummyConv2DWithBias")
|
||||
.Input("input: T")
|
||||
.Input("filter: T")
|
||||
|
@ -204,6 +204,7 @@ memory::format_tag MklTensorFormatToMklDnnDataFormat(MklTensorFormat format);
|
||||
|
||||
TensorFormat MklDnn3DDataFormatToTFDataFormat(MKL_TENSOR_FORMAT format);
|
||||
TensorFormat MklDnnDataFormatToTFDataFormat(MKL_TENSOR_FORMAT format);
|
||||
|
||||
memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
|
||||
memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
|
||||
const memory::dims& strides,
|
||||
@ -696,15 +697,24 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
|
||||
}
|
||||
|
||||
// Get the MKL shape from the second string tensor
|
||||
// Fills `mklshape` for the n-th input of `ctext`.
// In eager mode the shape is simply marked as not being an MKL tensor.
// Otherwise it is deserialized from the uint8 metadata tensor located at
// GetTensorMetaDataIndex(n, number of inputs).
inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape,
                        bool eager_mode) {
  if (eager_mode) {
    mklshape->SetMklTensor(false);
    return;
  }
  const Tensor& meta_tensor =
      ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()));
  const auto meta_bytes = meta_tensor.flat<uint8>();
  mklshape->DeSerializeMklDnnShape(meta_bytes.data(),
                                   meta_bytes.size() * sizeof(uint8));
}
|
||||
|
||||
inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
|
||||
mklshape->DeSerializeMklDnnShape(
|
||||
ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
|
||||
.flat<uint8>()
|
||||
.data(),
|
||||
ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
|
||||
.flat<uint8>()
|
||||
.size() *
|
||||
sizeof(uint8));
|
||||
GetMklShape(ctext, n, mklshape, false);
|
||||
}
|
||||
|
||||
// Gets the actual input
|
||||
@ -733,14 +743,15 @@ inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
|
||||
/// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
|
||||
/// If the input tensor is in MKL layout, then obtains TensorShape from
|
||||
/// MklShape.
|
||||
inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
|
||||
inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx,
|
||||
bool eager_mode = false) {
|
||||
// Sanity check.
|
||||
CHECK_NOTNULL(context);
|
||||
CHECK_LT(input_idx, context->num_inputs());
|
||||
|
||||
MklDnnShape input_mkl_shape;
|
||||
GetMklShape(context, input_idx, &input_mkl_shape);
|
||||
if (input_mkl_shape.IsMklTensor()) {
|
||||
GetMklShape(context, input_idx, &input_mkl_shape, eager_mode);
|
||||
if (input_mkl_shape.IsMklTensor() && !eager_mode) {
|
||||
return input_mkl_shape.GetTfShape();
|
||||
} else {
|
||||
const Tensor& t = MklGetInput(context, input_idx);
|
||||
@ -768,19 +779,22 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
|
||||
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
|
||||
Tensor** output,
|
||||
const TensorShape& tf_shape,
|
||||
const MklDnnShape& mkl_shape) {
|
||||
Tensor* second_tensor = nullptr;
|
||||
TensorShape second_shape;
|
||||
second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
|
||||
const MklDnnShape& mkl_shape,
|
||||
bool eager_mode = false) {
|
||||
OP_REQUIRES_OK(
|
||||
ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
|
||||
tf_shape, output));
|
||||
OP_REQUIRES_OK(ctext, ctext->allocate_output(
|
||||
GetTensorMetaDataIndex(n, ctext->num_outputs()),
|
||||
second_shape, &second_tensor));
|
||||
mkl_shape.SerializeMklDnnShape(
|
||||
second_tensor->flat<uint8>().data(),
|
||||
second_tensor->flat<uint8>().size() * sizeof(uint8));
|
||||
if (!eager_mode) {
|
||||
Tensor* second_tensor = nullptr;
|
||||
TensorShape second_shape;
|
||||
second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
|
||||
OP_REQUIRES_OK(ctext, ctext->allocate_output(
|
||||
GetTensorMetaDataIndex(n, ctext->num_outputs()),
|
||||
second_shape, &second_tensor));
|
||||
mkl_shape.SerializeMklDnnShape(
|
||||
second_tensor->flat<uint8>().data(),
|
||||
second_tensor->flat<uint8>().size() * sizeof(uint8));
|
||||
}
|
||||
}
|
||||
|
||||
// Allocates a temp tensor and returns the data buffer for temporary storage.
|
||||
|
Loading…
Reference in New Issue
Block a user