Merge pull request #30401 from Intel-tensorflow:eager-conv-fwd

PiperOrigin-RevId: 261397141
Commit: 660f39fd0d
@@ -3,6 +3,11 @@ load(
     "tf_cc_test",
     "tf_cuda_library",
 )
+load(
+    "//third_party/mkl:build_defs.bzl",
+    "if_mkl",
+    "mkl_deps",
+)
 
 package(
     default_visibility = [
@@ -266,7 +271,14 @@ cc_library(
         "//tensorflow/core/distributed_runtime/eager:remote_execute_node",
         "//tensorflow/core/distributed_runtime/eager:remote_copy_node",
     ],
-    }),
+    }) + if_mkl([":mkl_eager_op_rewrite"]),
+)
+
+cc_library(
+    name = "mkl_eager_op_rewrite",
+    srcs = ["mkl_eager_op_rewrite.cc"],
+    copts = ["-DINTEL_MKL=1"],
+    deps = [":eager_op_rewrite_registry"],
 )
 
 cc_library(
tensorflow/core/common_runtime/eager/mkl_eager_op_rewrite.cc (new file, 180 lines)

@@ -0,0 +1,180 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+#include "tensorflow/core/common_runtime/eager/eager_op_rewrite_registry.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"
+#include "tensorflow/core/graph/mkl_layout_pass.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/util/util.h"
+
+namespace tensorflow {
+
+class MklEagerOpRewrite : public EagerOpRewrite {
+ public:
+  MklEagerOpRewrite(string name, string file, string line);
+  typedef struct {
+    string op_name;
+    std::function<bool(EagerOperation*)> RewriteRule;
+    std::function<Status(EagerOperation*, std::unique_ptr<EagerOperation>*)>
+        CreateMklOp;
+  } MklEagerOp;
+
+ private:
+  // TODO(intel-tf): refactor with unordered_map;
+  // especially when adding more ops/rewrite rules in future.
+  std::vector<MklEagerOp> mkl_eager_ops_;
+
+  // The entry point to execute the op rewrite.
+  Status Run(EagerOperation* orig_op,
+             std::unique_ptr<tensorflow::EagerOperation>* out_op);
+
+  // Initializes the new op and sets up its inputs and attributes
+  static Status SetupNewOp(EagerOperation* orig_op, const string mkl_op_name,
+                           std::unique_ptr<EagerOperation>* new_mkl_op);
+
+  // Creates new MKL op for Conv2D, Conv2DBackpropInput and
+  // Conv2DBackpropFilter.
+  static Status CreateMklConv2DOp(
+      EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op);
+
+  // Rewrite rule for Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter.
+  static bool RewriteConv2D(EagerOperation* op);
+
+  // Calls op-specific rewrite function to create new MKL op.
+  Status RewriteToMklOp(EagerOperation* orig_op,
+                        std::unique_ptr<EagerOperation>* mkl_op,
+                        const int op_idx);
+
+  // Checks whether we can rewrite the op to MKL one or not.
+  bool ShouldRewriteOp(EagerOperation* op, int* op_idx);
+};
+
+REGISTER_REWRITE(EagerOpRewriteRegistry::PRE_EXECUTION, MklEagerOpRewrite);
+
+// Constructor
+MklEagerOpRewrite::MklEagerOpRewrite(string name, string file, string line)
+    : EagerOpRewrite(name, file, line) {
+  mkl_eager_ops_.push_back({"Conv2D", RewriteConv2D, CreateMklConv2DOp});
+  mkl_eager_ops_.push_back(
+      {"Conv2DBackpropInput", RewriteConv2D, CreateMklConv2DOp});
+  mkl_eager_ops_.push_back(
+      {"Conv2DBackpropFilter", RewriteConv2D, CreateMklConv2DOp});
+}
+
+Status MklEagerOpRewrite::Run(
+    EagerOperation* orig_op,
+    std::unique_ptr<tensorflow::EagerOperation>* out_op) {
+  int found_op_idx = -1;
+  if (ShouldRewriteOp(orig_op, &found_op_idx)) {
+    TF_CHECK_OK(RewriteToMklOp(orig_op, out_op, found_op_idx));
+  }
+  return Status::OK();
+}
+
+Status MklEagerOpRewrite::SetupNewOp(
+    EagerOperation* orig_op, const string mkl_op_name,
+    std::unique_ptr<EagerOperation>* new_mkl_op) {
+  const tensorflow::AttrTypeMap* types;
+  bool is_function = false;
+  TF_RETURN_IF_ERROR(
+      tensorflow::AttrTypeMapForOp(mkl_op_name.c_str(), &types, &is_function));
+  EagerContext* ctx = orig_op->EagerContext();
+  new_mkl_op->reset(new tensorflow::EagerOperation(ctx, mkl_op_name.c_str(),
+                                                   is_function, types));
+
+  int num_inputs = orig_op->Inputs().size();
+  // Add all inputs to the new op.
+  for (int i = 0; i < num_inputs; ++i) {
+    (*new_mkl_op)->AddInput(orig_op->Inputs()[i]);
+  }
+
+  // Copy all attributes to the new op.
+  string name;
+  const NodeDef& orig_ndef = orig_op->MutableAttrs()->BuildNodeDef();
+
+  AttrSlice attr_list(orig_ndef);
+  for (const auto& attr : attr_list) {
+    (*new_mkl_op)->MutableAttrs()->Set(attr.first, attr.second);
+  }
+
+  (*new_mkl_op)
+      ->MutableAttrs()
+      ->Set("_kernel", mkl_op_registry::kMklNameChangeOpLabel);
+
+  if (orig_op->Device() != nullptr) {
+    (*new_mkl_op)->SetDevice(orig_op->Device());
+  } else {
+    string device_name =
+        DeviceNameUtils::ParsedNameToString(orig_op->GetDeviceName());
+    (*new_mkl_op)->SetDeviceName(device_name.c_str());
+  }
+  return Status::OK();
+}
+
+Status MklEagerOpRewrite::CreateMklConv2DOp(
+    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_conv2d_op) {
+  const string mkl_op_name =
+      mkl_op_registry::GetMklEagerOpName(orig_op->Name());
+  TF_CHECK_OK(SetupNewOp(orig_op, mkl_op_name, mkl_conv2d_op));
+  return Status::OK();
+}
+
+bool MklEagerOpRewrite::ShouldRewriteOp(EagerOperation* op, int* op_idx) {
+  // Don't rewrite the op if MKL use is disabled at runtime.
+  if (DisableMKL()) {
+    return false;
+  }
+  DataType data_type;
+  if (op->Attrs().Get("T", &data_type) != Status::OK()) {
+    return false;
+  }
+  // Check if we have registered MKL kernel for this op.
+  if (!mkl_op_registry::IsMklNameChangeOp(
+          mkl_op_registry::GetMklEagerOpName(op->Name()), data_type) &&
+      !mkl_op_registry::IsMklNameChangeOp(
+          mkl_op_registry::GetMklOpName(op->Name()), data_type)) {
+    return false;
+  }
+
+  *op_idx = -1;
+  // Find and call the op's rewrite rule that determines whether we need to
+  // rewrite this op or not.
+  for (auto it = mkl_eager_ops_.begin(); it != mkl_eager_ops_.end(); ++it) {
+    if (it->op_name.compare(op->Name()) == 0 && it->RewriteRule(op)) {
+      *op_idx = it - mkl_eager_ops_.begin();
+      return true;
+    }
+  }
+  return false;
+}
+
+Status MklEagerOpRewrite::RewriteToMklOp(
+    EagerOperation* orig_op, std::unique_ptr<EagerOperation>* mkl_op,
+    const int op_idx) {
+  mkl_eager_ops_[op_idx].CreateMklOp(orig_op, mkl_op);
+  return Status::OK();
+}
+
+bool MklEagerOpRewrite::RewriteConv2D(EagerOperation* op) {
+  const NodeDef& ndef = op->MutableAttrs()->BuildNodeDef();
+  string padding;
+  TF_CHECK_OK(GetNodeAttr(ndef, "padding", &padding));
+  // Right now MKL Conv2D does not support explicit padding.
+  return (padding != "EXPLICIT");
+}
+
+}  // namespace tensorflow
+
+#endif  // INTEL_MKL
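The pass above is table-driven: each MklEagerOp entry pairs an op name with a predicate (RewriteRule) and a factory (CreateMklOp), and ShouldRewriteOp/RewriteToMklOp walk that table. A minimal standalone sketch of the same dispatch pattern, using only the C++ standard library (all names below are illustrative, not TensorFlow APIs):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for EagerOperation: just enough state for the rewrite decision.
struct Op {
  std::string name;
  std::string padding;
};

// Mirror of the MklEagerOp entry: op name + rewrite rule + factory.
struct RewriteEntry {
  std::string op_name;
  std::function<bool(const Op&)> rewrite_rule;
  std::function<Op(const Op&)> create_rewritten;
};

int main() {
  std::vector<RewriteEntry> table = {
      {"Conv2D",
       [](const Op& op) { return op.padding != "EXPLICIT"; },
       [](const Op& op) { return Op{"_MklEager" + op.name, op.padding}; }}};

  Op op{"Conv2D", "SAME"};
  for (const auto& entry : table) {
    if (entry.op_name == op.name && entry.rewrite_rule(op)) {
      std::cout << op.name << " -> " << entry.create_rewritten(op).name
                << "\n";  // prints: Conv2D -> _MklEagerConv2D
    }
  }
}

Keeping rule and factory side by side is what lets Conv2D, Conv2DBackpropInput and Conv2DBackpropFilter share one rule (RewriteConv2D) and one factory (CreateMklConv2DOp) while the registry stays a flat list.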
@@ -104,12 +104,24 @@ static const char* kMklQuantizedOpLabelPattern = "label='QuantizedMklOp'";
 
 // Prefix that we add to Tensorflow op name to construct Mkl op name.
 static const char* const kMklOpPrefix = "_Mkl";
+// TODO(intel-tf): PR review feedback (penpornk)
+// Can we add eager_mode (or is_eager) as an op attribute instead?
+// This way we don't need to rename the op just to pass eager_mode
+// through template parameter.
+static const char* const kMklEagerOpPrefix = "_MklEager";
 
 // Get the name of Mkl op from original TensorFlow op
 // We prefix 'Mkl' to the original op to get Mkl op.
 inline string GetMklOpName(const string& name) {
   return string(kMklOpPrefix) + name;
 }
+
+// Get the name of Mkl Eager op from original TensorFlow op
+// We prefix 'MklEager' to the original op to get Mkl Eager op.
+inline string GetMklEagerOpName(const string& name) {
+  return string(kMklEagerOpPrefix) + name;
+}
 
 // Check whether opname with type T is registered as MKL operator
 // that can accept input tensors in MKL layout.
 //
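The two helpers differ only in the prefix constant, so both the graph pass and the eager rewrite can derive a kernel name purely from the original op name. A standalone sketch of the resulting mapping (mirroring the functions above):

#include <iostream>
#include <string>

static const char* const kMklOpPrefix = "_Mkl";
static const char* const kMklEagerOpPrefix = "_MklEager";

std::string GetMklOpName(const std::string& name) {
  return std::string(kMklOpPrefix) + name;
}

std::string GetMklEagerOpName(const std::string& name) {
  return std::string(kMklEagerOpPrefix) + name;
}

int main() {
  std::cout << GetMklOpName("Conv2D") << "\n";       // _MklConv2D
  std::cout << GetMklEagerOpName("Conv2D") << "\n";  // _MklEagerConv2D
}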
@@ -518,7 +518,8 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
 // Base class for convolution forward operations
 template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
           typename Toutput, typename Ttemp_output, typename Tpadding,
-          bool bias_enabled, bool pad_enabled, bool is_depthwise>
+          bool bias_enabled, bool pad_enabled, bool is_depthwise,
+          bool eager_mode>
 class MklConvOp : public OpKernel {
  public:
   ~MklConvOp() {}

@@ -545,8 +546,10 @@ class MklConvOp : public OpKernel {
                     "strides in the batch and depth dimensions."));
     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
     is_filter_const_ = false;
-    OP_REQUIRES_OK(context,
-                   context->GetAttr("is_filter_const", &is_filter_const_));
+    if (context->HasAttr("is_filter_const")) {
+      OP_REQUIRES_OK(context,
+                     context->GetAttr("is_filter_const", &is_filter_const_));
+    }
 
     if (strides_.size() == 4) {
       OP_REQUIRES(context, dilations_.size() == 4,

@@ -589,8 +592,9 @@ class MklConvOp : public OpKernel {
       const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
 
       MklDnnShape src_mkl_shape, filter_mkl_shape;
-      GetMklShape(context, kInputIndex_Src, &src_mkl_shape);
-      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape);
+      GetMklShape(context, kInputIndex_Src, &src_mkl_shape, eager_mode);
+      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape, eager_mode);
+
       OP_REQUIRES(context, filter_mkl_shape.IsMklTensor() == false,
                   errors::InvalidArgument("Filter should not be in "
                                           "Mkl Layout"));

@@ -620,8 +624,9 @@ class MklConvOp : public OpKernel {
       // Get shapes of input tensors in MKL-DNN order
       MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_,
                               dilations_);
-      auto src_tf_shape = GetTfShape(context, kInputIndex_Src);
-      auto filter_tf_shape = GetTfShape(context, kInputIndex_Filter);
+      auto src_tf_shape = GetTfShape(context, kInputIndex_Src, eager_mode);
+      auto filter_tf_shape =
+          GetTfShape(context, kInputIndex_Filter, eager_mode);
       conv_utl.GetConvFwdSizesInMklOrder(
           src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
           &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,

@@ -634,15 +639,17 @@ class MklConvOp : public OpKernel {
 
       // Corner cases: output with 0 elements and 0 batch size.
       Tensor* dst_tensor = nullptr;
+      Tensor tmp_tensor;
      bool emit_filter_output = (typeid(Tinput) == typeid(Tfilter) &&
                                 typeid(Tinput) == typeid(Toutput) &&
                                 (typeid(Tinput) == typeid(float) ||
-                                 typeid(Tinput) == typeid(bfloat16)));
+                                 typeid(Tinput) == typeid(bfloat16))) &&
+                                !eager_mode;
       if (dst_tf_shape.num_elements() == 0 || dst_dims_tf_order[0] == 0) {
         MklDnnShape dst_mkl_shape;
         dst_mkl_shape.SetMklTensor(false);
         AllocateOutputSetMklShape(context, kOutputIndex_Dst, &dst_tensor,
-                                  dst_tf_shape, dst_mkl_shape);
+                                  src_tf_shape, dst_mkl_shape, eager_mode);
 
         // MklConv2D/3D also outputs converted filter as 2nd output.
         filter_mkl_shape.SetMklTensor(false);

@@ -754,9 +761,10 @@ class MklConvOp : public OpKernel {
                                                convFwdDims, do_not_cache);
 
       // Allocate output tensors `output_tensor` and `filter_out_tensor`
+      MklDnnShape output_mkl_shape;
       std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
       AllocateOutputTensor(context, *conv_fwd_pd, dst_dims_mkl_order, tf_fmt,
-                           &dst_tensor);
+                           &output_mkl_shape, &dst_tensor, &tmp_tensor);
 
       Tensor* filter_out_tensor = nullptr;
       if (emit_filter_output) {

@@ -828,7 +836,28 @@ class MklConvOp : public OpKernel {
             this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
         conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
       } else {
-        conv_fwd->Execute(src_data, filter_data, dst_data);
+        if (!eager_mode) {
+          conv_fwd->Execute(src_data, filter_data, dst_data);
+        } else {
+          // In eager mode we first write the output to temporary
+          // buffer in MKL format. Then we convert the data to TF format.
+          Ttemp_output* tmp_data = reinterpret_cast<Ttemp_output*>(
+              tmp_tensor.flat<Toutput>().data());
+          conv_fwd->Execute(src_data, filter_data, tmp_data);
+
+          // Now we need to convert the output to TF format.
+          auto output_tf_md = output_mkl_shape.GetTfLayout();
+          auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
+          auto dst_pd = (*conv_fwd_pd).dst_primitive_desc();
+          mkldnn::reorder::primitive_desc reorder_pd =
+              mkldnn::reorder::primitive_desc(dst_pd, output_tf_pd);
+          std::vector<mkldnn::primitive> net;
+          memory* tmp_data_mem = new memory(dst_pd, tmp_data);
+          memory* dst_data_mem = new memory(output_tf_pd, dst_data);
+          net.push_back(
+              mkldnn::reorder(reorder_pd, *tmp_data_mem, *dst_data_mem));
+          stream(stream::kind::eager).submit(net).wait();
+        }
       }
 
       // Delete primitive since it is not cached.
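The eager branch above runs the convolution into a temporary buffer in the blocked MKL layout and then issues an MKL-DNN reorder into plain NCHW, because eager callers never see MKL-layout tensors or their serialized shape metadata. A minimal sketch of that reorder pattern against the MKL-DNN 0.x API used in this file (dimensions and buffers are illustrative):

#include <vector>
#include "mkldnn.hpp"

using namespace mkldnn;

// Reorders a blocked (nChw16c) float buffer into plain NCHW.
void ReorderBlockedToNchw(void* blocked_buf, void* nchw_buf) {
  engine cpu_engine(engine::cpu, 0);
  memory::dims dims = {1, 16, 8, 8};  // illustrative N, C, H, W

  // Source: blocked layout as produced by the convolution primitive.
  auto src_md =
      memory::desc(dims, memory::data_type::f32, memory::format::nChw16c);
  // Destination: the TensorFlow-visible NCHW layout.
  auto dst_md =
      memory::desc(dims, memory::data_type::f32, memory::format::nchw);

  memory src_mem(memory::primitive_desc(src_md, cpu_engine), blocked_buf);
  memory dst_mem(memory::primitive_desc(dst_md, cpu_engine), nchw_buf);

  std::vector<primitive> net;
  net.push_back(reorder(src_mem, dst_mem));
  stream(stream::kind::eager).submit(net).wait();
}

Note that stream::kind::eager here is MKL-DNN's eager execution stream, unrelated to TensorFlow's eager mode.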
@@ -942,7 +971,9 @@ class MklConvOp : public OpKernel {
                             const ConvFwdPd& conv_prim_desc,
                             const memory::dims& output_dims_mkl_order,
                             MKL_TENSOR_FORMAT output_tf_format,
-                            Tensor** output_tensor) {
+                            MklDnnShape* output_mkl_shape,
+                            Tensor** output_tensor,
+                            Tensor* tmp_tensor) {
     DCHECK(output_tensor);
 #ifdef ENABLE_MKLDNN_V1
     auto dst_md = conv_prim_desc.dst_desc();

@@ -960,25 +991,29 @@ class MklConvOp : public OpKernel {
     }
 
     // Allocate shape of MKL tensor
-    MklDnnShape output_mkl_shape;
-    output_mkl_shape.SetMklTensor(true);
-    output_mkl_shape.SetMklLayout(&DST_MD);
-    output_mkl_shape.SetElemType(MklDnnType<Toutput>());
-    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
-                                 output_dims_mkl_order, output_tf_format);
+    output_mkl_shape->SetMklTensor(true);
+    output_mkl_shape->SetMklLayout(&DST_MD);
+    output_mkl_shape->SetElemType(MklDnnType<Toutput>());
+    output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
+                                  output_dims_mkl_order, output_tf_format);
 
     // Allocate shape of TF tensor
     TensorShape output_tf_shape;
     output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));
+    if (eager_mode) {
+      AllocTmpBuffer<Toutput>(context, tmp_tensor, output_tf_shape);
+      output_tf_shape = output_mkl_shape->GetTfShape();
+    }
 
     AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
-                              output_tf_shape, output_mkl_shape);
+                              output_tf_shape, *output_mkl_shape, eager_mode);
     if (fuse_add_) {
       const Tensor& add_tensor = MklGetInput(context, kInputIndex_Add);
       MklDnnShape add_mkl_shape;
       GetMklShape(context, kInputIndex_Add, &add_mkl_shape);
 
       // Check if reorder is needed
-      if (add_mkl_shape == output_mkl_shape) {
+      if (add_mkl_shape == *output_mkl_shape) {
         DCHECK((*output_tensor)->CopyFrom(add_tensor, output_tf_shape));
       } else {
         if (add_mkl_shape.IsMklTensor()) {
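The eager path sizes its scratch tensor from DST_MD.get_size() rather than from the logical output shape: blocked layouts such as nChw16c can pad the channel dimension, so the primitive may need more bytes than the logical element count implies. A standalone sketch of that sizing rule (numbers are illustrative):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Logical output: 1 x 10 x 8 x 8 floats.
  const size_t logical_elems = 1 * 10 * 8 * 8;  // 640

  // An nChw16c layout rounds 10 channels up to 16, so the primitive
  // descriptor reports the padded byte size (the get_size() analogue).
  const size_t mkl_bytes = 1 * 16 * 8 * 8 * sizeof(float);

  // Mirror of output_tf_shape.AddDim(DST_MD.get_size() / sizeof(Toutput)):
  // the scratch tensor is a flat buffer with one dimension of this length.
  std::vector<float> scratch(mkl_bytes / sizeof(float));

  std::cout << "logical elems: " << logical_elems
            << ", scratch elems: " << scratch.size() << "\n";  // 640 vs 1024
}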
@@ -986,14 +1021,14 @@ class MklConvOp : public OpKernel {
         } else {
 #ifdef ENABLE_MKLDNN_V1
           auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
-              output_mkl_shape.GetTfDataFormat());
+              output_mkl_shape->GetTfDataFormat());
           DCHECK_NE(output_format_tag, memory::format_tag::undef);
           auto add_md = memory::desc(output_dims_mkl_order,
                                      MklDnnType<Toutput>(), output_format_tag);
 #else
           auto add_md =
               memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
-                           output_mkl_shape.GetTfDataFormat());
+                           output_mkl_shape->GetTfDataFormat());
           auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
 #endif  // ENABLE_MKLDNN_V1
           void* add_buf = static_cast<void*>(

@@ -1290,11 +1325,11 @@ template <typename Device, typename Tinput, typename Tfilter, typename Tbias,
           bool pad_enabled>
 class MklFusedConvOp
     : public MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
-                       Tpadding, false, false, false> {
+                       Tpadding, false, false, false, false> {
  public:
   explicit MklFusedConvOp(OpKernelConstruction* context)
       : MklConvOp<Device, Tinput, Tfilter, Tbias, Toutput, Ttemp_output,
-                  Tpadding, false, false, false>(context) {
+                  Tpadding, false, false, false, false>(context) {
     // Since we came here through the registration of _MklFusedConv2D, get
     // all information from 'fused_ops' and 'num_args'
     std::vector<string> fused_ops;

@@ -1386,7 +1421,7 @@ template <typename Device, typename Tbias, typename Toutput,
           typename Ttemp_output, bool bias_enabled, bool is_depthwise>
 class MklQuantizedConv2DOp
     : public MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output,
-                       int32, bias_enabled, false, is_depthwise> {
+                       int32, bias_enabled, false, is_depthwise, false> {
  public:
   virtual ~MklQuantizedConv2DOp() {
     if (this->input_bias_ != nullptr) {

@@ -1402,7 +1437,7 @@ class MklQuantizedConv2DOp
 
   explicit MklQuantizedConv2DOp(OpKernelConstruction* context)
       : MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
-                  bias_enabled, false, is_depthwise>(context) {
+                  bias_enabled, false, is_depthwise, false>(context) {
     bool is_filter_const;
     OP_REQUIRES_OK(context,
                    context->GetAttr("is_filter_const", &is_filter_const));

@@ -1413,7 +1448,7 @@ class MklQuantizedConv2DOp
   void Compute(OpKernelContext* context) override {
     // Compute int32 output tensor
     MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
-              bias_enabled, false, is_depthwise>::Compute(context);
+              bias_enabled, false, is_depthwise, false>::Compute(context);
 
     // Compute additional outputs: min/max scalars.
     int bias_index_offset;

@@ -1475,8 +1510,8 @@ class MklQuantizedConv2DOp
   void ExtendConvFwdParams(OpKernelContext* context,
                            MklConvFwdParams& params) override {
     MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
-              bias_enabled, false, is_depthwise>::ExtendConvFwdParams(context,
-                                                                      params);
+              bias_enabled, false, is_depthwise,
+              false>::ExtendConvFwdParams(context, params);
 
     // When the output type is quint8, the output data id requantized
     // into quint8. A post_op "output_scale" is added to do the conversion.

@@ -1674,7 +1709,9 @@ class MklQuantizedConv2DSumReluOp
                             const ConvFwdPd& conv_prim_desc,
                             const memory::dims& output_dims_mkl_order,
                             MKL_TENSOR_FORMAT output_tf_format,
-                            Tensor** output_tensor) override {
+                            MklDnnShape* output_mkl_shape,
+                            Tensor** output_tensor,
+                            Tensor* tmp_tensor) override {
     int summand_idx = context->num_inputs() / 2 - 1;
     if (std::is_same<Toutput, quint8>::value) {
       summand_idx -= 2;

@@ -1701,12 +1738,12 @@ class MklQuantizedConv2DSumReluOp
       *output_tensor = const_cast<Tensor*>(&summand);
       return;
     }
 
     MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
-              bias_enabled, false,
+              bias_enabled, false, false,
               false>::AllocateOutputTensor(context, conv_prim_desc,
                                            output_dims_mkl_order,
-                                           output_tf_format, output_tensor);
+                                           output_tf_format, output_mkl_shape,
+                                           output_tensor, tmp_tensor);
     const Tensor& summand = MklGetInput(context, summand_idx);
     if (summand.dtype() != DT_FLOAT)
       TF_CHECK_OK(Status(error::Code::FAILED_PRECONDITION,

@@ -2120,13 +2157,13 @@ REGISTER_KERNEL_BUILDER(
           .Device(DEVICE_CPU)                                                  \
           .TypeConstraint<T>("T")                                              \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false>);        \
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>); \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("_MklConv2DWithBias")                                               \
           .Device(DEVICE_CPU)                                                  \
           .TypeConstraint<T>("T")                                              \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false>);         \
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, true, false, false, false>);  \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("__MklDummyConv2DWithBias")                                         \
           .Device(DEVICE_CPU)                                                  \

@@ -2139,21 +2176,27 @@ REGISTER_KERNEL_BUILDER(
           .TypeConstraint<T>("T")                                              \
           .TypeConstraint<int32>("Tpaddings")                                  \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false>);         \
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, true, false, false>);  \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("_MklPadWithConv2D")                                                \
           .Device(DEVICE_CPU)                                                  \
           .TypeConstraint<T>("T")                                              \
           .TypeConstraint<int64>("Tpaddings")                                  \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
-      MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false>);         \
+      MklConvOp<CPUDevice, T, T, T, T, T, int64, false, true, false, false>);  \
   REGISTER_KERNEL_BUILDER(                                                     \
       Name("__MklDummyPadWithConv2D")                                          \
          .Device(DEVICE_CPU)                                                  \
           .TypeConstraint<T>("T")                                              \
           .TypeConstraint<int32>("Tpaddings")                                  \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
-      MklDummyOp<CPUDevice, T>);
+      MklDummyOp<CPUDevice, T>);                                               \
+  REGISTER_KERNEL_BUILDER(                                                     \
+      Name("_MklEagerConv2D")                                                  \
+          .Device(DEVICE_CPU)                                                  \
+          .TypeConstraint<T>("T")                                              \
+          .Label(mkl_op_registry::kMklNameChangeOpLabel),                      \
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, true>);
 
 TF_CALL_float(REGISTER_MKL_CPU_2D);
 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D);
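The _MklEagerConv2D registration uses .Label(mkl_op_registry::kMklNameChangeOpLabel) rather than the layout-dependent label, and it is selected at run time by the "_kernel" attribute that MklEagerOpRewrite::SetupNewOp stamps on the rewritten op. A sketch of that label-based lookup against TensorFlow's kernel-registry API, assuming an MKL-enabled build (the wrapper function FindEagerConvKernel is illustrative):

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/mkl_graph_util.h"

namespace tensorflow {

// Looks up the kernel registered above; the lookup succeeds only because
// the "_kernel" attr matches the .Label() used at registration time.
Status FindEagerConvKernel() {
  NodeDef ndef;
  ndef.set_op("_MklEagerConv2D");
  AddNodeAttr("T", DT_FLOAT, &ndef);
  // Exactly what MklEagerOpRewrite::SetupNewOp sets on the new op.
  AddNodeAttr("_kernel", mkl_op_registry::kMklNameChangeOpLabel, &ndef);
  const KernelDef* kdef = nullptr;
  return FindKernelDef(DeviceType(DEVICE_CPU), ndef, &kdef,
                       /*kernel_class_name=*/nullptr);
}

}  // namespace tensorflow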
@@ -2164,7 +2207,7 @@ TF_CALL_bfloat16(REGISTER_MKL_CPU_2D);
           .Device(DEVICE_CPU)                                                  \
           .TypeConstraint<T>("T")                                              \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true>);
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, true, false>);
 
 TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE);
 TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_DEPTHWISE);

@@ -2210,7 +2253,7 @@ TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED);
           .Device(DEVICE_CPU)                                                  \
           .TypeConstraint<T>("T")                                              \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),                 \
-      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false>);
+      MklConvOp<CPUDevice, T, T, T, T, T, int32, false, false, false, false>);
 
 TF_CALL_float(REGISTER_MKL_CPU_3D);
 TF_CALL_bfloat16(REGISTER_MKL_CPU_3D);
@@ -1652,6 +1652,25 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
 expected to invoke these operators.
 )doc");
 
+REGISTER_OP("_MklEagerConv2D")
+    .Input("input: T")
+    .Input("filter: T")
+    .Output("output: T")
+    .Attr("T: {bfloat16, float}")
+    .Attr("strides: list(int)")
+    .Attr("use_cudnn_on_gpu: bool = true")
+    .Attr(GetPaddingAttrStringWithExplicit())
+    .Attr(GetExplicitPaddingsAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
+    .Doc(R"doc(
+MKL version of Conv2D operator for Eager mode. Uses MKL DNN APIs to perform 2D convolution.
+
+NOTE Do not invoke this operator directly in Python. Eager Op rewrite is
+expected to invoke these operators.
+)doc");
+
 REGISTER_OP("__MklDummyConv2DWithBias")
     .Input("input: T")
     .Input("filter: T")
@@ -204,6 +204,7 @@ memory::format_tag MklTensorFormatToMklDnnDataFormat(MklTensorFormat format);
 
 TensorFormat MklDnn3DDataFormatToTFDataFormat(MKL_TENSOR_FORMAT format);
 TensorFormat MklDnnDataFormatToTFDataFormat(MKL_TENSOR_FORMAT format);
+
 memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
 memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
                                         const memory::dims& strides,

@@ -696,7 +697,9 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
 }
 
 // Get the MKL shape from the second string tensor
-inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
+inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape,
+                        bool eager_mode) {
+  if (!eager_mode) {
     mklshape->DeSerializeMklDnnShape(
         ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
             .flat<uint8>()

@@ -705,6 +708,13 @@ inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
             .flat<uint8>()
             .size() *
         sizeof(uint8));
+  } else {
+    mklshape->SetMklTensor(false);
+  }
+}
+
+inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
+  GetMklShape(ctext, n, mklshape, false);
 }
 
 // Gets the actual input
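In eager mode there is no second serialized-metadata input to deserialize, so the overload above simply marks the shape as non-MKL; GetTfShape (next hunk) then falls back to the ordinary TensorFlow shape of the input. A short sketch of the resulting call pattern inside a kernel, assuming an MKL-enabled build (the function and input index are hypothetical):

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

TensorShape QueryEagerInputShape(OpKernelContext* context) {
  MklDnnShape mkl_shape;
  // No metadata tensor exists in eager mode, so this sets
  // IsMklTensor() == false instead of deserializing anything.
  GetMklShape(context, 0 /* input index */, &mkl_shape, true /* eager_mode */);
  DCHECK(!mkl_shape.IsMklTensor());

  // Falls through to the plain TensorFlow shape of input 0.
  return GetTfShape(context, 0, /*eager_mode=*/true);
}

}  // namespace tensorflow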
@@ -733,14 +743,15 @@ inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
 /// Get shape of input tensor pointed by 'input_idx' in TensorShape format.
 /// If the input tensor is in MKL layout, then obtains TensorShape from
 /// MklShape.
-inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
+inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx,
+                              bool eager_mode = false) {
   // Sanity check.
   CHECK_NOTNULL(context);
   CHECK_LT(input_idx, context->num_inputs());
 
   MklDnnShape input_mkl_shape;
-  GetMklShape(context, input_idx, &input_mkl_shape);
-  if (input_mkl_shape.IsMklTensor()) {
+  GetMklShape(context, input_idx, &input_mkl_shape, eager_mode);
+  if (input_mkl_shape.IsMklTensor() && !eager_mode) {
     return input_mkl_shape.GetTfShape();
   } else {
     const Tensor& t = MklGetInput(context, input_idx);

@@ -768,13 +779,15 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
                                       Tensor** output,
                                       const TensorShape& tf_shape,
-                                      const MklDnnShape& mkl_shape) {
-  Tensor* second_tensor = nullptr;
-  TensorShape second_shape;
-  second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
+                                      const MklDnnShape& mkl_shape,
+                                      bool eager_mode = false) {
   OP_REQUIRES_OK(
       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
                                     tf_shape, output));
+  if (!eager_mode) {
+    Tensor* second_tensor = nullptr;
+    TensorShape second_shape;
+    second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
     OP_REQUIRES_OK(ctext, ctext->allocate_output(
                               GetTensorMetaDataIndex(n, ctext->num_outputs()),
                               second_shape, &second_tensor));

@@ -782,6 +795,7 @@ inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
                               second_tensor->flat<uint8>().data(),
                               second_tensor->flat<uint8>().size() * sizeof(uint8));
   }
+}
 
 // Allocates a temp tensor and returns the data buffer for temporary storage.
 template <typename T>
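The eager_mode flag matters here because of the MKL tensor-ordering convention: each logical MKL output normally occupies two slots, a data tensor at GetTensorDataIndex and a serialized MklDnnShape tensor at GetTensorMetaDataIndex, and the eager path allocates only the former. A standalone sketch of that contiguous ordering (index formulas assumed from the helpers in this header):

#include <iostream>

// Contiguous ordering: all data tensors first, all metadata tensors second.
int DataIndex(int n, int total_tensors) { return n; }
int MetaDataIndex(int n, int total_tensors) { return total_tensors / 2 + n; }

int main() {
  // A kernel with two logical outputs has four slots in graph mode:
  std::cout << DataIndex(0, 4) << " " << MetaDataIndex(0, 4) << "\n";  // 0 2
  std::cout << DataIndex(1, 4) << " " << MetaDataIndex(1, 4) << "\n";  // 1 3
  // In eager mode only the data slots are populated; the eager_mode flag in
  // AllocateOutputSetMklShape skips the metadata allocation entirely.
}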