Added identity op and a fix for LRN (#9641)
* relu grad and maxpooling grad fixes for perf
* Graph layout pass and conversion pass changes. This commit makes the following changes:
  - Enables support for ReluGrad and BiasAddGrad
  - Adds support for detecting depthwise/batchwise pooling
  - Adds more unit tests for the graph rewrite pass
  - Improvements to handling of control-flow edges
  - Bug fixes
* Defaulting to Eigen when LRN depth_radius != 2
* Fixed mkl_conv_grad_filter.cc for conv_ops_tests.py
* Style fix to mkl_matmul and removal of the unnecessary 'MKL' label on the matmul kernel
* Style fixes based on clang-format to mkl_conv_* and mkl_matmul
* Bug fixes
* Adding an OP_REQUIRES_OK check in Concat
* Making some style changes
* Enabled the configuration of MKL settings
* Fixing graph unit tests after the Mkl op name change to _Mkl; fixed a missing _ in the MklToTf op
* Fixed missing libdl.so.2 in the BUILD file
* Fixes for unit test build failures
* Changes in mkl_conv_grad_filter_ops.cc for Google code style
* Removed dead code and added a TODO for the MKL implementation to handle this case in the future
* Enabling MklIdentityOp
* Calling MKL for all values of depth_radius in LRN
* Fixed buildifier sanity check error
* Adding support for Google's CI automation
* Updated the link to the new MKL version
* Fix for the missing locate command in CI; added updatedb to populate the database after installing mlocate
* Setting tf_need_mkl=0 in libtf files
* Added third_party/mkl/* and third_party/eigen3/mkl_include to .gitignore
* In configure, set MKL-enabling options only for Linux
* Style and indentation fixes in LRN
parent 27dd167c5f · commit 243ac01320
@@ -736,6 +736,7 @@ cc_library(
         "//tensorflow/core/kernels:mkl_concat_op",
         "//tensorflow/core/kernels:mkl_conv_op",
         "//tensorflow/core/kernels:mkl_fused_batch_norm_op",
+        "//tensorflow/core/kernels:mkl_identity_op",
         "//tensorflow/core/kernels:mkl_lrn_op",
         "//tensorflow/core/kernels:mkl_pooling_ops",
         "//tensorflow/core/kernels:mkl_relu_op",
@@ -2123,6 +2124,7 @@ tf_cc_test_mkl(
         "//tensorflow/core/kernels:mkl_concat_op",
         "//tensorflow/core/kernels:mkl_conv_op",
         "//tensorflow/core/kernels:mkl_fused_batch_norm_op",
+        "//tensorflow/core/kernels:mkl_identity_op",
         "//tensorflow/core/kernels:mkl_lrn_op",
         "//tensorflow/core/kernels:mkl_pooling_ops",
         "//tensorflow/core/kernels:mkl_relu_op",
@@ -273,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
     csinfo_.fused_batch_norm = "FusedBatchNorm";
     csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
+    csinfo_.identity = "Identity";
     csinfo_.lrn = "LRN";
     csinfo_.lrn_grad = "LRNGrad";
     csinfo_.matmul = "MatMul";
@@ -326,6 +327,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.fused_batch_norm_grad,
                       GetMklOpName(csinfo_.fused_batch_norm_grad),
                       CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
+    rinfo_.push_back({csinfo_.identity,
+                      GetMklOpName(csinfo_.identity),
+                      CopyAttrsIdentity, AlwaysRewrite, nullptr});
     rinfo_.push_back({csinfo_.lrn,
                       GetMklOpName(csinfo_.lrn),
                       CopyAttrsLRN, AlwaysRewrite, nullptr});
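For readers unfamiliar with the rewrite-pass machinery: each rinfo_ entry pairs an original op name with its _Mkl* replacement, an attribute-copying callback, and a rewrite predicate. Below is a minimal standalone sketch of that table pattern; all names and types are illustrative stand-ins, not TensorFlow's actual internals.

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Illustrative stand-in for a rewrite-table entry: original op name,
    // Mkl replacement name, and a predicate deciding whether to rewrite.
    struct RewriteInfo {
      std::string name;
      std::string new_name;
      std::function<bool()> rewrite_rule;
    };

    int main() {
      std::vector<RewriteInfo> rinfo;
      rinfo.push_back({"Identity", "_MklIdentity", [] { return true; }});  // AlwaysRewrite
      rinfo.push_back({"LRN", "_MklLRN", [] { return true; }});
      for (const auto& ri : rinfo) {
        if (ri.rewrite_rule()) {
          std::cout << ri.name << " -> " << ri.new_name << "\n";
        }
      }
      return 0;
    }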
@@ -446,6 +450,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string conv2d_grad_filter;
     string fused_batch_norm;
     string fused_batch_norm_grad;
+    string identity;
     string lrn;
     string lrn_grad;
     string matmul;
@@ -767,6 +772,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
   static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
   static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
   static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
+  static void CopyAttrsIdentity(const Node* orig_node, NodeBuilder* nb);
   static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
   static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
   static void CopyAttrsRelu(const Node* orig_node, NodeBuilder* nb);
@@ -1275,6 +1281,16 @@ void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orig_node,
   nb->Attr("data_format", data_format);
 }
+
+void MklLayoutRewritePass::CopyAttrsIdentity(const Node* orig_node,
+                                             NodeBuilder* nb) {
+  DataType T;
+
+  // Get all attributes from old node.
+  TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+  // Add attributes to new node.
+  nb->Attr("T", T);
+}
 
 void MklLayoutRewritePass::CopyAttrsLRN(const Node* orig_node,
                                         NodeBuilder* nb) {
   DataType T;
@@ -4918,6 +4918,14 @@ tf_mkl_kernel_library(
     ],
 )
 
+tf_mkl_kernel_library(
+    name = "mkl_identity_op",
+    prefix = "mkl_identity_op",
+    deps = ARRAY_DEPS + [
+        "//third_party/mkl:intel_binary_blob",
+    ],
+)
+
 tf_mkl_kernel_library(
     name = "mkl_lrn_op",
     prefix = "mkl_lrn_op",
tensorflow/core/kernels/mkl_identity_op.cc (new file, 63 lines)
@@ -0,0 +1,63 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/array_ops.cc.
+#ifdef INTEL_MKL
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/core/util/mkl_util.h"
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+template <typename Device, typename T>
+class MklIdentityOp : public OpKernel {
+ public:
+  explicit MklIdentityOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* context) override {
+    MklShape mkl_shape_input;
+    GetMklShape(context, 0, &mkl_shape_input);
+    bool input_in_mkl_format = mkl_shape_input.IsMklTensor();
+
+    if (input_in_mkl_format) {
+      ForwarMklTensorInToOut(context, 0, 0);
+    } else {
+      FowardTfTensorInToOut(context, 0, 0);
+    }
+  }
+
+  bool IsExpensive() override { return false; }
+};
+
+#define REGISTER_MKL_CPU(T)                                         \
+  REGISTER_KERNEL_BUILDER(Name("_MklIdentity")                      \
+                              .Device(DEVICE_CPU)                   \
+                              .TypeConstraint<T>("T")               \
+                              .Label(mkl_op_registry::kMklOpLabel), \
+                          MklIdentityOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_MKL_CPU);
+#undef REGISTER_MKL_CPU
+}  // namespace tensorflow
+#endif  // INTEL_MKL
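A note on the kernel above: IsExpensive() returning false tells the executor this kernel is cheap enough to run inline rather than being scheduled onto a thread pool, which fits an op that only forwards buffers. The two forwarding helpers it calls, ForwarMklTensorInToOut and FowardTfTensorInToOut (spelled as in this commit), are added to mkl_util.h further down in this diff.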
@@ -171,9 +171,7 @@ class MklLRNOp : public OpKernel {
     MklShape input_shape;
     dnnPrimitive_t lrn_fwd = nullptr;
     dnnPrimitive_t convert_input = nullptr;
-    /* dnnPrimitive_t convert_output; */
     dnnLayout_t lt_input = nullptr;
-    /* dnnLayout_t lt_output; */
     dnnLayout_t lt_internal_input = nullptr;
     dnnLayout_t lt_internal_workspace = nullptr;
     dnnLayout_t lt_internal_output = nullptr;
@@ -390,11 +388,6 @@ class MklLRNGradOp : public OpKernel {
       return;
     }
 
-    if (depth_radius_ != 2) {
-      mkl_context.MklDefaultToEigen(context);
-      return;
-    }
-
     if (ingrad_in_mkl_format || inimage_in_mkl_format) {
       const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
                                           ? &mkl_context.ingrad_shape
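This hunk removes the depth_radius != 2 fallback to Eigen in the backward pass, matching the "Calling MKL for all values of depth_radius in LRN" item in the commit message: the MKL path now handles every depth_radius value instead of punting to the default implementation.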
@@ -476,11 +469,11 @@ class MklLRNGradOp : public OpKernel {
         const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
 
     Tensor mkl_tmp_input_buf_tensor, mkl_tmp_image_buf_tensor,
-        mkl_tmp_outimage_buf_tensor, mkl_tmp_workspace_buf_tensor;
+        mkl_tmp_outimage_buf_tensor;
     // Convert Inputs if needed
-    mkl_context.MklPrepareLRNGradInput(
-        context, &mkl_tmp_input_buf_tensor, &mkl_tmp_image_buf_tensor,
-        &mkl_tmp_outimage_buf_tensor, &mkl_tmp_workspace_buf_tensor);
+    mkl_context.MklPrepareLRNGradInput(context, &mkl_tmp_input_buf_tensor,
+                                       &mkl_tmp_image_buf_tensor,
+                                       &mkl_tmp_outimage_buf_tensor);
 
     // We do not do any conversion for output. But we simply emit it
     // in MKL format.
@@ -537,11 +530,13 @@ class MklLRNGradOp : public OpKernel {
     void MklPrepareLRNGradInput(OpKernelContext* context,
                                 Tensor* mkl_tmp_input_buf_tensor,
                                 Tensor* mkl_tmp_image_buf_tensor,
-                                Tensor* mkl_tmp_outimage_buf_tensor,
-                                Tensor* mkl_tmp_workspace_buf_tensor) {
+                                Tensor* mkl_tmp_outimage_buf_tensor) {
       const Tensor& in_grads = MklGetInput(context, 0);
       const Tensor& in_image = MklGetInput(context, 1);
       const Tensor& out_image = MklGetInput(context, 2);
+      const Tensor& workspace = MklGetInput(
+          context,
+          3); /* Workspace is enabled, get the buffer to the workspace */
+
       void* user_input = const_cast<void*>(
           static_cast<const void*>(in_grads.flat<T>().data()));
@@ -549,6 +544,9 @@ class MklLRNGradOp : public OpKernel {
           static_cast<const void*>(in_image.flat<T>().data()));
       void* user_fwd_output = const_cast<void*>(
           static_cast<const void*>(out_image.flat<T>().data()));
+      void* workspace_buffer = const_cast<void*>(
+          static_cast<const void*>(workspace.flat<T>().data()));
+
       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, lrn_bwd,
                                                 dnnResourceWorkspace),
               E_SUCCESS);
@@ -623,9 +621,7 @@ class MklLRNGradOp : public OpKernel {
         res_lrn_bwd[dnnResourceDst] = user_fwd_output;
       }
 
-      // Allocate buffer for workspace.
-      AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor, lt_workspace,
-                     &res_lrn_bwd[dnnResourceWorkspace]);
+      res_lrn_bwd[dnnResourceWorkspace] = workspace_buffer;
     }
 
     // Fallback implementation - Taken from lrn_op.cc
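The three LRN-grad hunks above share one change: instead of allocating a temporary workspace with AllocTmpBuffer and letting MKL fill it again, the backward op now binds the workspace the forward op already produced, received as input 3. A standalone sketch of that pattern follows; the names here are illustrative stand-ins, not the MKL DNN API.

    #include <cstdio>
    #include <vector>

    int main() {
      // The forward LRN op computes a workspace as a by-product and emits it
      // as an extra output; the backward op receives it as an extra input.
      std::vector<float> fwd_workspace(16, 1.0f);

      // Before this commit (conceptually): allocate a scratch buffer and
      // recompute the workspace in the backward pass.
      //   std::vector<float> tmp_workspace(16);
      //   res_lrn_bwd[kWorkspace] = tmp_workspace.data();

      // After: bind the already-computed forward workspace directly, skipping
      // the allocation. kWorkspace stands in for MKL's dnnResourceWorkspace.
      const int kWorkspace = 0;
      void* res_lrn_bwd[1];
      res_lrn_bwd[kWorkspace] = fwd_workspace.data();
      std::printf("workspace bound at %p\n", res_lrn_bwd[kWorkspace]);
      return 0;
    }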
@@ -713,7 +709,7 @@ class MklLRNGradOp : public OpKernel {
     Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
           depth * depth, shard);
   }
 
   // release mkl resources
   void Mklcleanup() {
     bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
@@ -129,7 +129,7 @@ class MklReshapeOp : public OpKernel {
         return;
       }
     } else {
-      CopyTFTensorInToOut(context, 0, 0, shape);
+      CopyTfTensorInToOutWithShape(context, 0, 0, shape);
     }
   }
 };
@ -1506,6 +1506,23 @@ REGISTER_OP("Identity")
|
|||||||
Return a tensor with the same shape and contents as the input tensor or value.
|
Return a tensor with the same shape and contents as the input tensor or value.
|
||||||
)Doc");
|
)Doc");
|
||||||
|
|
||||||
|
#ifdef INTEL_MKL
|
||||||
|
REGISTER_OP("_MklIdentity")
|
||||||
|
.Input("input: T")
|
||||||
|
.Input("mkl_input: uint8")
|
||||||
|
.Output("output: T")
|
||||||
|
.Output("mkl_output: uint8")
|
||||||
|
.Attr("T: type")
|
||||||
|
.SetShapeFn([](shape_inference::InferenceContext* c) {
|
||||||
|
c->set_output(0, c->input(0));
|
||||||
|
c->set_output_handle_dtype(0, c->input_handle_dtype(0));
|
||||||
|
c->set_output_handle_shape(0, c->input_handle_shape(0));
|
||||||
|
return Status::OK();
|
||||||
|
})
|
||||||
|
.Doc(R"Doc( Mkl implementation of IdentityOp
|
||||||
|
)Doc");
|
||||||
|
#endif
|
||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
REGISTER_OP("RefIdentity")
|
REGISTER_OP("RefIdentity")
|
||||||
.Input("input: Ref(T)")
|
.Input("input: Ref(T)")
|
||||||
|
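Two things are worth noting in the _MklIdentity registration: every MKL-layout op carries a second uint8 tensor of layout metadata alongside each data tensor (hence mkl_input/mkl_output), and the shape function simply passes input 0 through to output 0. A minimal sketch of such pass-through shape inference, using illustrative types rather than TensorFlow's shape_inference API:

    #include <iostream>
    #include <vector>

    // Illustrative stand-in for a shape: one extent per dimension.
    using Shape = std::vector<long long>;

    // Pass-through inference: output 0 inherits input 0's shape unchanged,
    // mirroring c->set_output(0, c->input(0)) in the registration above.
    Shape InferIdentityShape(const Shape& input) { return input; }

    int main() {
      Shape in{2, 3, 4};
      Shape out = InferIdentityShape(in);
      std::cout << "output rank: " << out.size() << '\n';  // output rank: 3
      return 0;
    }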
@@ -542,8 +542,8 @@ inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
   return mkl_shape.dim_size(index);
 }
 
-inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
-                                 int idx_out) {
+inline void CopyMklTensorInToOut(OpKernelContext* context,
+                                 int idx_in, int idx_out) {
   int num_inputs = context->num_inputs();
   int num_outputs = context->num_outputs();
   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
@@ -563,8 +563,9 @@ inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
   context->set_output(idx_meta_out, meta_output);
 }
 
-inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in,
-                                int idx_out, const TensorShape& shape) {
+inline void CopyTfTensorInToOutWithShape(OpKernelContext* context,
+                                         int idx_in, int idx_out,
+                                         const TensorShape& shape) {
   int num_inputs = context->num_inputs();
   int num_outputs = context->num_outputs();
   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
@@ -580,6 +581,41 @@ inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in,
   context->set_output(idx_data_out, output);
 }
 
+inline void FowardTfTensorInToOut(OpKernelContext* context,
+                                  int idx_in, int idx_out) {
+  int num_inputs = context->num_inputs();
+  int num_outputs = context->num_outputs();
+  int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
+  int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
+
+  MklShape mkl_shape_output;
+  mkl_shape_output.SetMklTensor(false);
+  AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
+  if (IsRefType(context->input_dtype(idx_data_in))) {
+    context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
+  } else {
+    context->set_output(idx_data_out, context->input(idx_data_in));
+  }
+}
+
+inline void ForwarMklTensorInToOut(OpKernelContext* context,
+                                   int idx_in, int idx_out) {
+  int num_inputs = context->num_inputs();
+  int num_outputs = context->num_outputs();
+  int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
+  int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
+  int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
+  int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
+
+  if (IsRefType(context->input_dtype(idx_data_in))) {
+    context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
+    context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
+  } else {
+    context->set_output(idx_data_out, context->input(idx_data_in));
+    context->set_output(idx_meta_out, context->input(idx_meta_in));
+  }
+}
+
 namespace mkl_op_registry {
 static const char* kMklOpLabel = "MklOp";
 static const char* kMklOpLabelPattern = "label='MklOp'";
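The two forwarding helpers above rely on GetTensorDataIndex/GetTensorMetaDataIndex to locate the data slot and the layout-metadata slot belonging to one logical input. The sketch below illustrates that pairing; the index formula is an assumption made for the example (all data slots first, then all meta slots), not necessarily TensorFlow's actual layout.

    #include <cstdio>

    // Assumed slot layout for this example: data tensors first, then their
    // metadata tensors. Stand-ins for GetTensorDataIndex/GetTensorMetaDataIndex.
    int DataIndex(int n, int total_slots) { (void)total_slots; return n; }
    int MetaIndex(int n, int total_slots) { return total_slots / 2 + n; }

    int main() {
      const int num_input_slots = 4;  // two logical inputs -> four slots
      for (int n = 0; n < num_input_slots / 2; ++n) {
        std::printf("logical input %d: data slot %d, meta slot %d\n", n,
                    DataIndex(n, num_input_slots),
                    MetaIndex(n, num_input_slots));
      }
      return 0;
    }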