Initial FusedMatMul kernel implementation: MatMul+BiasAdd.
PiperOrigin-RevId: 242742812
Parent: 244cb0b925, commit: 2c12f4c8c5
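What the new fused kernel computes, as a rough sketch (illustrative only, simplified from the actual kernel code in matmul_op_fused.cc below): for a: [m, k], b: [k, n] and bias: [n], the kernel evaluates out = activation(a * b + bias) in a single pass by attaching an output kernel to the Eigen tensor contraction, so the bias (and any activation) is applied to each output block as it is produced, instead of materializing the intermediate MatMul result first:

    auto lhs = a.matrix<float>();        // Eigen tensor maps over tf Tensors
    auto rhs = b.matrix<float>();
    auto out = output->matrix<float>();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dims = {
        Eigen::IndexPair<Eigen::DenseIndex>(1, 0)};       // contract over k
    // bias_args is a BiasAddArgs<float> pointing at the bias vector.
    out.device(cpu_device) =
        lhs.contract(rhs, dims, BiasAddOutputKernel<float>(bias_args));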
@ -126,6 +126,7 @@ tensorflow/core/kernels/fill_functor.cc
tensorflow/core/kernels/fft_ops.cc
tensorflow/core/kernels/function_ops.cc
tensorflow/core/kernels/fused_batch_norm_op.cc
tensorflow/core/kernels/fused_eigen_output_kernels.cc
tensorflow/core/kernels/gather_functor.cc
tensorflow/core/kernels/gather_nd_op.cc
tensorflow/core/kernels/gather_nd_op_cpu_impl_0.cc
@ -773,6 +773,18 @@ cc_library(
    ],
)

cc_library(
    name = "fused_eigen_output_kernels",
    srcs = ["fused_eigen_output_kernels.cc"],
    hdrs = ["fused_eigen_output_kernels.h"],
    deps = [
        ":eigen_contraction_kernel",
        "//tensorflow/core:framework",
        "//third_party/eigen3",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "eigen_helpers",
    hdrs = [
@ -3431,6 +3443,7 @@ tf_kernel_library(
    name = "matmul_op",
    srcs = [
        "matmul_op.cc",
        "matmul_op_fused.cc",
    ] + if_mkl([
        "mkl_matmul_op.cc",
    ]),
@ -3444,6 +3457,7 @@ tf_kernel_library(
    }),
    deps = MATH_DEPS + [
        ":eigen_contraction_kernel",
        ":fused_eigen_output_kernels",
        ":gpu_utils",
    ] + select({
        ":xsmm": ["@libxsmm_archive//:xsmm_avx"],
@ -3628,11 +3642,16 @@ tf_cuda_cc_test(
        ":quantized_ops",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:client_session",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:tensorflow",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "@com_google_absl//absl/algorithm:container",
    ],
)

@ -3873,6 +3892,7 @@ tf_kernel_library(
        ":eigen_contraction_kernel",
        ":image_resizer_state",
        ":fill_functor",
        ":fused_eigen_output_kernels",
        ":ops_util",
        "@com_google_absl//absl/base:dynamic_annotations",
        "@com_google_absl//absl/strings",
@ -5850,10 +5870,14 @@ filegroup(
        "depthwise_conv_op.cc",
        "dynamic_partition_op.cc",
        "encode_wav_op.cc",
        "eigen_contraction_kernel.cc",
        "eigen_contraction_kernel.h",
        "fake_quant_ops.cc",
        "fifo_queue.cc",
        "fifo_queue_op.cc",
        "fused_batch_norm_op.cc",
        "fused_eigen_output_kernels.cc",
        "fused_eigen_output_kernels.h",
        "listdiff_op.cc",
        "population_count_op.cc",
        "population_count_op.h",
@ -51,6 +51,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
@ -69,28 +71,6 @@ class AutotuneResult;
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
// Supported Conv2D fusions. Not all of them are supported on all types of devices.
|
||||
enum class FusedComputationType {
|
||||
kUndefined,
|
||||
// NOTE(ezhulenev): CuDNN `cudnnConvolutionBiasActivationForward` supports
|
||||
// identity activation function, it in theory should allow to fuse convolution
|
||||
// with BiasAdd, but in practice it doesn't work, cuDNN ignores this parameter
|
||||
// and always does Relu activation.
|
||||
kBiasAdd, // CPU
|
||||
kBiasAddWithRelu, // CPU and GPU
|
||||
kBiasAddWithRelu6, // CPU
|
||||
kBiasAddWithElu, // CPU
|
||||
kFusedBatchNorm, // CPU
|
||||
kFusedBatchNormWithRelu, // CPU
|
||||
kFusedBatchNormWithRelu6, // CPU
|
||||
kFusedBatchNormWithElu // CPU
|
||||
};
|
||||
|
||||
// We have to pass around additional arguments for all possible fusion types.
|
||||
struct FusedComputationArgs {
|
||||
float epsilon = 0.0; // Used by `FusedBatchNorm` fusion only
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct LaunchFusedConv2DOp {
|
||||
void operator()(OpKernelContext* context, bool use_cudnn,
|
||||
@ -101,165 +81,6 @@ struct LaunchFusedConv2DOp {
|
||||
const Conv2DDimensions& dimensions, Tensor* output);
|
||||
};
|
||||
|
||||
// Type alias for the tensor contraction output mapper.
|
||||
template <typename Scalar, typename Index>
|
||||
using ContractionOutputMapper =
|
||||
Eigen::internal::blas_data_mapper<Scalar, Index, Eigen::ColMajor>;
|
||||
|
||||
// Returns input expression without any transformations.
|
||||
struct Identity {
|
||||
template <typename XprType>
|
||||
static auto apply(XprType expr) -> XprType {
|
||||
return expr;
|
||||
};
|
||||
};
|
||||
|
||||
// Applies `Relu` to the passed input expression.
|
||||
struct Relu {
|
||||
template <typename XprType>
|
||||
static auto apply(XprType expr)
|
||||
-> decltype(expr.cwiseMax(std::declval<typename XprType::Scalar>())) {
|
||||
return expr.cwiseMax(static_cast<typename XprType::Scalar>(0));
|
||||
};
|
||||
};
|
||||
|
||||
// Applies `Relu6` to the passed input expression.
|
||||
struct Relu6 {
|
||||
template <typename XprType>
|
||||
static auto apply(XprType expr)
|
||||
-> decltype(expr.cwiseMax(std::declval<typename XprType::Scalar>())
|
||||
.cwiseMin(std::declval<typename XprType::Scalar>())) {
|
||||
return expr.cwiseMax(static_cast<typename XprType::Scalar>(0))
|
||||
.cwiseMin(static_cast<typename XprType::Scalar>(6));
|
||||
};
|
||||
};
|
||||
|
||||
// Applies `Elu` to the passed input expression.
|
||||
struct Elu {
|
||||
template <typename XprType>
|
||||
static auto apply(XprType expr) -> decltype(
|
||||
(expr < std::declval<typename XprType::Scalar>())
|
||||
.select(expr.exp() -
|
||||
expr.constant(std::declval<typename XprType::Scalar>()),
|
||||
expr)) {
|
||||
return (expr < static_cast<typename XprType::Scalar>(0))
|
||||
.select(expr.exp() -
|
||||
expr.constant(static_cast<typename XprType::Scalar>(1)),
|
||||
expr);
|
||||
};
|
||||
};
|
||||
|
||||
// TensorContraction swaps lhs with rhs, and changes layout from RowMajor
|
||||
// (default in Tensorflow) to ColMajor (preferred in Eigen), and computes matmul
|
||||
// using these tensors.
|
||||
//
|
||||
// TensorContraction output matrix (before reshape) has a ColMajor layout, and
|
||||
// has dimensions:
|
||||
// - rows: output_channels
|
||||
// - cols: all other dimensions
|
||||
//
|
||||
// First element in every column is:
|
||||
// [batch ??, height ??, width ??, out_channel = i]
|
||||
//
|
||||
// We do not know what are the values of the 'batch', 'height', and 'width' here
|
||||
// (if we know original dimensions, they can be computed from 'j').
|
||||
//
|
||||
// Each column of an output block is a continuous slice along the output channel
|
||||
// dimension, so we can use it to efficiently compute any transformation that
|
||||
// depends only on a channel value (e.g. add channel bias).
|
||||
|
||||
// Output kernel that fuses BiasAdd operation into the output of tensor
|
||||
// contraction + activation function defined by Activation.
|
||||
template <typename T, typename Activation = Identity>
|
||||
struct BiasAddOutputKernel {
|
||||
explicit BiasAddOutputKernel(const T* bias_data) : bias_data(bias_data) {}
|
||||
|
||||
template <typename Index, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const ContractionOutputMapper<Scalar, Index>& output_mapper,
|
||||
const Eigen::TensorContractionParams& params, Index i, Index j,
|
||||
Index num_rows, Index num_cols) const {
|
||||
DCHECK(params.swapped_arguments);
|
||||
|
||||
const T* bias_base = bias_data + i;
|
||||
typename TTypes<T>::UnalignedConstTensor bias(bias_base, num_rows);
|
||||
|
||||
for (int col = 0; col < num_cols; ++col) {
|
||||
T* output_base = &output_mapper(0, col);
|
||||
typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
|
||||
const auto expr = output + bias;
|
||||
output = Activation::template apply<decltype(expr)>(expr);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const T* bias_data;
|
||||
};
|
||||
|
||||
// Output kernel that fuses FusedBatchNorm operation into the output of tensor
|
||||
// contraction + activation function defined by Activation.
|
||||
template <typename T, typename Activation = Identity>
|
||||
struct FusedBatchNormOutputKernel {
|
||||
FusedBatchNormOutputKernel(T epsilon, const T* scaling_factor_data,
|
||||
const T* offset_data, const T* estimated_mean_data)
|
||||
: epsilon(epsilon),
|
||||
scaling_factor_data(scaling_factor_data),
|
||||
offset_data(offset_data),
|
||||
estimated_mean_data(estimated_mean_data) {}
|
||||
|
||||
template <typename Index, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const ContractionOutputMapper<Scalar, Index>& output_mapper,
|
||||
const Eigen::TensorContractionParams& params, Index i, Index j,
|
||||
Index num_rows, Index num_cols) const {
|
||||
DCHECK(params.swapped_arguments);
|
||||
|
||||
const T* scaling_factor_base = scaling_factor_data + i;
|
||||
const T* offset_base = offset_data + i;
|
||||
const T* mean_base = estimated_mean_data + i;
|
||||
|
||||
typename TTypes<T>::UnalignedConstTensor scaling_factor(scaling_factor_base,
|
||||
num_rows);
|
||||
typename TTypes<T>::UnalignedConstTensor offset(offset_base, num_rows);
|
||||
typename TTypes<T>::UnalignedConstTensor mean(mean_base, num_rows);
|
||||
|
||||
for (int col = 0; col < num_cols; ++col) {
|
||||
T* output_base = &output_mapper(0, col);
|
||||
typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
|
||||
|
||||
auto scaled = (output - mean) * scaling_factor;
|
||||
auto shifted = scaled + offset;
|
||||
|
||||
output = Activation::template apply<decltype(shifted)>(shifted);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
T epsilon;
|
||||
const T* scaling_factor_data;
|
||||
const T* offset_data;
|
||||
const T* estimated_mean_data;
|
||||
};
|
||||
|
||||
// Type aliases for the output kernels, purely for the sake of better launch
|
||||
// dispatching code readability.
|
||||
template <typename T>
|
||||
using WithBiasAdd = BiasAddOutputKernel<T>;
|
||||
template <typename T>
|
||||
using WithBiasAddAndRelu = BiasAddOutputKernel<T, Relu>;
|
||||
template <typename T>
|
||||
using WithBiasAddAndRelu6 = BiasAddOutputKernel<T, Relu6>;
|
||||
template <typename T>
|
||||
using WithBiasAddAndElu = BiasAddOutputKernel<T, Elu>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNorm = FusedBatchNormOutputKernel<T>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNormAndRelu = FusedBatchNormOutputKernel<T, Relu>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNormAndRelu6 = FusedBatchNormOutputKernel<T, Relu6>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNormAndElu = FusedBatchNormOutputKernel<T, Elu>;
|
||||
|
||||
// This is CPU-only implementation that uses Eigen contraction output kernels.
|
||||
//
|
||||
// Dispatch 2D convolution to the appropriate primitive operation:
|
||||
@ -346,8 +167,17 @@ struct LaunchFusedConv2DOp<CPUDevice, T> {
|
||||
errors::Unimplemented("Fused conv implementation only supports "
|
||||
"NHWC tensor format for now."));
|
||||
|
||||
BiasAddArgs bias_add;
|
||||
FusedBatchNormArgs fused_batch_norm;
|
||||
BiasAddArgs<T> bias_add_args;
|
||||
if (BiasAddArgs<T>::IsSupported(fusion)) {
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
|
||||
}
|
||||
|
||||
FusedBatchNormArgs<T> fused_batch_norm_args;
|
||||
if (FusedBatchNormArgs<T>::IsSupported(fusion)) {
|
||||
OP_REQUIRES_OK(context,
|
||||
InitFusedBatchNormArgs(context, fusion_args.epsilon,
|
||||
&fused_batch_norm_args));
|
||||
}
|
||||
|
||||
LaunchFusedConv2DWithOutputKernel<T> conv2d(
|
||||
dimensions.stride_rows, dimensions.stride_cols,
|
||||
@ -357,148 +187,43 @@ struct LaunchFusedConv2DOp<CPUDevice, T> {
|
||||
case FusedComputationType::kUndefined:
|
||||
OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
|
||||
break;
|
||||
|
||||
case FusedComputationType::kBiasAdd:
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add));
|
||||
conv2d(WithBiasAdd<T>(bias_add.bias_add_data), context, input, filter,
|
||||
conv2d(WithBiasAdd<T>(bias_add_args), context, input, filter, output);
|
||||
break;
|
||||
case FusedComputationType::kBiasAddWithRelu:
|
||||
conv2d(WithBiasAddAndRelu<T>(bias_add_args), context, input, filter,
|
||||
output);
|
||||
break;
|
||||
|
||||
case FusedComputationType::kBiasAddWithRelu:
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add));
|
||||
conv2d(WithBiasAddAndRelu<T>(bias_add.bias_add_data), context, input,
|
||||
filter, output);
|
||||
break;
|
||||
|
||||
case FusedComputationType::kBiasAddWithRelu6:
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add));
|
||||
conv2d(WithBiasAddAndRelu6<T>(bias_add.bias_add_data), context, input,
|
||||
filter, output);
|
||||
conv2d(WithBiasAddAndRelu6<T>(bias_add_args), context, input, filter,
|
||||
output);
|
||||
break;
|
||||
|
||||
case FusedComputationType::kBiasAddWithElu:
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add));
|
||||
conv2d(WithBiasAddAndElu<T>(bias_add.bias_add_data), context, input,
|
||||
filter, output);
|
||||
conv2d(WithBiasAddAndElu<T>(bias_add_args), context, input, filter,
|
||||
output);
|
||||
break;
|
||||
|
||||
case FusedComputationType::kFusedBatchNorm:
|
||||
OP_REQUIRES_OK(context,
|
||||
InitFusedBatchNormArgs(context, fusion_args.epsilon,
|
||||
&fused_batch_norm));
|
||||
conv2d(WithFusedBatchNorm<T>(fusion_args.epsilon,
|
||||
fused_batch_norm.scaling_factor.data(),
|
||||
fused_batch_norm.offset_data,
|
||||
fused_batch_norm.estimated_mean_data),
|
||||
context, input, filter, output);
|
||||
conv2d(
|
||||
WithFusedBatchNorm<T>(fusion_args.epsilon, fused_batch_norm_args),
|
||||
context, input, filter, output);
|
||||
break;
|
||||
|
||||
case FusedComputationType::kFusedBatchNormWithRelu:
|
||||
OP_REQUIRES_OK(context,
|
||||
InitFusedBatchNormArgs(context, fusion_args.epsilon,
|
||||
&fused_batch_norm));
|
||||
conv2d(WithFusedBatchNormAndRelu<T>(
|
||||
fusion_args.epsilon, fused_batch_norm.scaling_factor.data(),
|
||||
fused_batch_norm.offset_data,
|
||||
fused_batch_norm.estimated_mean_data),
|
||||
conv2d(WithFusedBatchNormAndRelu<T>(fusion_args.epsilon,
|
||||
fused_batch_norm_args),
|
||||
context, input, filter, output);
|
||||
break;
|
||||
|
||||
case FusedComputationType::kFusedBatchNormWithRelu6:
|
||||
OP_REQUIRES_OK(context,
|
||||
InitFusedBatchNormArgs(context, fusion_args.epsilon,
|
||||
&fused_batch_norm));
|
||||
conv2d(WithFusedBatchNormAndRelu6<T>(
|
||||
fusion_args.epsilon, fused_batch_norm.scaling_factor.data(),
|
||||
fused_batch_norm.offset_data,
|
||||
fused_batch_norm.estimated_mean_data),
|
||||
conv2d(WithFusedBatchNormAndRelu6<T>(fusion_args.epsilon,
|
||||
fused_batch_norm_args),
|
||||
context, input, filter, output);
|
||||
break;
|
||||
|
||||
case FusedComputationType::kFusedBatchNormWithElu:
|
||||
OP_REQUIRES_OK(context,
|
||||
InitFusedBatchNormArgs(context, fusion_args.epsilon,
|
||||
&fused_batch_norm));
|
||||
conv2d(WithFusedBatchNormAndElu<T>(
|
||||
fusion_args.epsilon, fused_batch_norm.scaling_factor.data(),
|
||||
fused_batch_norm.offset_data,
|
||||
fused_batch_norm.estimated_mean_data),
|
||||
conv2d(WithFusedBatchNormAndElu<T>(fusion_args.epsilon,
|
||||
fused_batch_norm_args),
|
||||
context, input, filter, output);
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
struct BiasAddArgs {
|
||||
const T* bias_add_data = nullptr;
|
||||
};
|
||||
|
||||
struct FusedBatchNormArgs {
|
||||
const T* scale_data = nullptr;
|
||||
const T* offset_data = nullptr;
|
||||
const T* estimated_mean_data = nullptr;
|
||||
const T* estimated_variance_data = nullptr;
|
||||
|
||||
// Precomputed expression:
|
||||
// scaling_factor = (estimated_variance + epsilon).rsqrt() * scale
|
||||
Eigen::Tensor<T, 1, Eigen::RowMajor> scaling_factor;
|
||||
};
|
||||
|
||||
#define TF_REQUIRES(EXP, STATUS) \
|
||||
if (!TF_PREDICT_TRUE(EXP)) return (STATUS)
|
||||
|
||||
void InitDataPtr(const Tensor& tensor, const T** ptr) const {
|
||||
*ptr = reinterpret_cast<const T*>(tensor.tensor_data().data());
|
||||
}
|
||||
|
||||
Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs* args) const {
|
||||
// Bias of the following dimensions: [ output_depth ]
|
||||
const Tensor& bias = context->input(2);
|
||||
|
||||
TF_REQUIRES(bias.dims() == 1,
|
||||
errors::InvalidArgument("bias must be 1-dimensional",
|
||||
bias.shape().DebugString()));
|
||||
|
||||
InitDataPtr(bias, &args->bias_add_data);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon,
|
||||
FusedBatchNormArgs* args) const {
|
||||
const Tensor& scale = context->input(2);
|
||||
const Tensor& offset = context->input(3);
|
||||
const Tensor& estimated_mean = context->input(4);
|
||||
const Tensor& estimated_variance = context->input(5);
|
||||
|
||||
TF_REQUIRES(scale.dims() == 1,
|
||||
errors::InvalidArgument("scale must be 1-dimensional",
|
||||
scale.shape().DebugString()));
|
||||
TF_REQUIRES(offset.dims() == 1,
|
||||
errors::InvalidArgument("offset must be 1-dimensional",
|
||||
offset.shape().DebugString()));
|
||||
TF_REQUIRES(estimated_mean.dims() == 1,
|
||||
errors::InvalidArgument("estimated_mean must be 1-dimensional",
|
||||
estimated_mean.shape().DebugString()));
|
||||
TF_REQUIRES(
|
||||
estimated_variance.dims() == 1,
|
||||
errors::InvalidArgument("estimated_variance must be 1-dimensional",
|
||||
estimated_variance.shape().DebugString()));
|
||||
|
||||
InitDataPtr(scale, &args->scale_data);
|
||||
InitDataPtr(offset, &args->offset_data);
|
||||
InitDataPtr(estimated_mean, &args->estimated_mean_data);
|
||||
InitDataPtr(estimated_variance, &args->estimated_variance_data);
|
||||
|
||||
// Precompute scaling factor once for all output blocks (kernels).
|
||||
args->scaling_factor =
|
||||
(estimated_variance.flat<T>() + static_cast<T>(epsilon)).rsqrt() *
|
||||
scale.flat<T>();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
#undef TF_REQUIRES
|
||||
};
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
@ -894,66 +619,33 @@ class FusedConv2DOp : public OpKernel {
|
||||
use_cudnn_ &= CanUseCudnn();
|
||||
cudnn_use_autotune_ = CudnnUseAutotune();
|
||||
|
||||
// 'fused_ops' and 'num_args' attributes are specified by the Grappler
|
||||
// Remapper optimizer (see grappler/optimizers/remapper.cc).
|
||||
|
||||
std::vector<string> fused_ops;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
|
||||
OP_REQUIRES(context, !fused_ops.empty(),
|
||||
errors::InvalidArgument(
|
||||
"Fused Conv2D must have at least one fused op."));
|
||||
|
||||
int num_args;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
|
||||
|
||||
// TODO(ezhulenev): Add support for fusion element-wise op chains defined
|
||||
// at runtime, e.g. Relu+Sqrt+Tanh+etc.
|
||||
|
||||
using FCT = FusedComputationType;
|
||||
std::vector<std::pair<FusedConv2DPattern, FusedComputationType>> mappings =
|
||||
{{{{"BiasAdd"}, true}, FCT::kBiasAdd},
|
||||
{{{"BiasAdd", "Relu"}, false}, FCT::kBiasAddWithRelu},
|
||||
{{{"BiasAdd", "Relu6"}, true}, FCT::kBiasAddWithRelu6},
|
||||
{{{"BiasAdd", "Elu"}, true}, FCT::kBiasAddWithElu},
|
||||
{{{"FusedBatchNorm"}, true}, FCT::kFusedBatchNorm},
|
||||
{{{"FusedBatchNorm", "Relu"}, true}, FCT::kFusedBatchNormWithRelu},
|
||||
{{{"FusedBatchNorm", "Relu6"}, true}, FCT::kFusedBatchNormWithRelu6},
|
||||
{{{"FusedBatchNorm", "Elu"}, true}, FCT::kFusedBatchNormWithElu}};
|
||||
|
||||
// Match op fusion to one of the supported patterns.
|
||||
for (const auto& mapping : mappings) {
|
||||
const FusedConv2DPattern& pattern = mapping.first;
|
||||
if (FusedOpsMatchAndSupportedOnDevice(fused_ops, pattern.fused_ops,
|
||||
pattern.cpu_only)) {
|
||||
fused_computation_ = mapping.second;
|
||||
}
|
||||
}
|
||||
if (fused_computation_ == FusedComputationType::kUndefined) {
|
||||
OP_REQUIRES(context, false,
|
||||
errors::Unimplemented("Fusion is not implemented: [",
|
||||
absl::StrJoin(fused_ops, ","), "]"));
|
||||
std::vector<FusedComputationPattern> patterns;
|
||||
if (std::is_same<Device, CPUDevice>::value) {
|
||||
patterns = {
|
||||
{FCT::kBiasAdd, {"BiasAdd"}},
|
||||
{FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
|
||||
{FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
|
||||
{FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
|
||||
{FCT::kFusedBatchNorm, {"FusedBatchNorm"}},
|
||||
{FCT::kFusedBatchNormWithRelu, {"FusedBatchNorm", "Relu"}},
|
||||
{FCT::kFusedBatchNormWithRelu6, {"FusedBatchNorm", "Relu6"}},
|
||||
{FCT::kFusedBatchNormWithElu, {"FusedBatchNorm", "Elu"}},
|
||||
};
|
||||
}
|
||||
|
||||
// Depending on the picked fusion type, validate fusion-specific arguments.
|
||||
if (fused_computation_ == FusedComputationType::kBiasAdd ||
|
||||
fused_computation_ == FusedComputationType::kBiasAddWithRelu ||
|
||||
fused_computation_ == FusedComputationType::kBiasAddWithRelu6 ||
|
||||
fused_computation_ == FusedComputationType::kBiasAddWithElu) {
|
||||
OP_REQUIRES(context, num_args == 1,
|
||||
errors::InvalidArgument(
|
||||
"Fused Conv2D must have one extra argument: bias."));
|
||||
// NOTE(ezhulenev): cuDNN `cudnnConvolutionBiasActivationForward` supports the
// identity activation function, which in theory should allow fusing a
// convolution with BiasAdd alone, but in practice it doesn't work: cuDNN
// ignores this parameter and always applies Relu activation.
|
||||
if (std::is_same<Device, GPUDevice>::value) {
|
||||
patterns = {{FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}}};
|
||||
}
|
||||
|
||||
if (fused_computation_ == FusedComputationType::kFusedBatchNorm ||
|
||||
fused_computation_ == FusedComputationType::kFusedBatchNormWithRelu ||
|
||||
fused_computation_ == FusedComputationType::kFusedBatchNormWithRelu6 ||
|
||||
fused_computation_ == FusedComputationType::kFusedBatchNormWithElu) {
|
||||
OP_REQUIRES(
|
||||
context, num_args == 4,
|
||||
errors::InvalidArgument("Fused FusedBatchNorm must have four extra "
|
||||
"arguments: scale, offset, mean, variance."));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon_));
|
||||
}
|
||||
OP_REQUIRES_OK(context, InitializeFusedComputation(
|
||||
context, "Conv2D", patterns,
|
||||
&fused_computation_, &fused_computation_args_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
@ -995,36 +687,19 @@ class FusedConv2DOp : public OpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
FusedComputationArgs args;
|
||||
args.epsilon = epsilon_;
|
||||
|
||||
LaunchFusedConv2DOp<Device, T>()(context, use_cudnn_, cudnn_use_autotune_,
|
||||
input, filter, fused_computation_, args,
|
||||
params_, dimensions, output);
|
||||
input, filter, fused_computation_,
|
||||
fused_computation_args_, params_,
|
||||
dimensions, output);
|
||||
}
|
||||
|
||||
private:
|
||||
struct FusedConv2DPattern {
|
||||
std::vector<string> fused_ops;
|
||||
bool cpu_only;
|
||||
};
|
||||
|
||||
bool FusedOpsMatchAndSupportedOnDevice(const std::vector<string>& fused_ops,
|
||||
const std::vector<string>& expected,
|
||||
bool cpu_only) const {
|
||||
if (std::is_same<Device, GPUDevice>::value && cpu_only) {
|
||||
return false;
|
||||
}
|
||||
return fused_ops == expected;
|
||||
}
|
||||
|
||||
Conv2DParameters params_;
|
||||
bool use_cudnn_;
|
||||
bool cudnn_use_autotune_;
|
||||
|
||||
FusedComputationType fused_computation_ = FusedComputationType::kUndefined;
|
||||
|
||||
float epsilon_; // Used only in FusedBatchNorm fusion
|
||||
FusedComputationArgs fused_computation_args_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(FusedConv2DOp);
|
||||
};
|
||||
|
@ -997,8 +997,6 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, SpatialConvolution) {
|
||||
this->VerifyConv2DWithBias(filter_size, filter_count);
|
||||
}
|
||||
|
||||
// Relu --------------------------------------------------------------------- //
|
||||
|
||||
TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndActivation) {
|
||||
const int filter_size = 1;
|
||||
const int filter_count = 12;
|
||||
|
tensorflow/core/kernels/fused_eigen_output_kernels.cc (new file, 88 lines)
@ -0,0 +1,88 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
|
||||
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status InitializeFusedComputation(
|
||||
OpKernelConstruction* context, const string& kernel_name,
|
||||
const std::vector<FusedComputationPattern>& patterns,
|
||||
FusedComputationType* fused_computation,
|
||||
FusedComputationArgs* fused_computation_args) {
|
||||
// 'fused_ops' and 'num_args' attributes are specified by the Grappler
|
||||
// Remapper optimizer (see grappler/optimizers/remapper.cc).
|
||||
|
||||
std::vector<string> fused_ops;
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("fused_ops", &fused_ops));
|
||||
if (fused_ops.empty()) {
|
||||
return errors::InvalidArgument("Fused ", kernel_name,
|
||||
" must have at least one fused op.");
|
||||
}
|
||||
|
||||
int num_args;
|
||||
TF_RETURN_IF_ERROR(context->GetAttr("num_args", &num_args));
|
||||
|
||||
// TODO(ezhulenev): Add support for fusing element-wise op chains defined
|
||||
// at runtime, e.g. Relu+Sqrt+Tanh+etc.
|
||||
|
||||
// Reset fused computation type.
|
||||
*fused_computation = FusedComputationType::kUndefined;
|
||||
|
||||
// Match op fusion to one of the supported patterns.
|
||||
for (const auto& pattern : patterns) {
|
||||
if (fused_ops == pattern.fused_ops) {
|
||||
*fused_computation = pattern.fused_computation;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (*fused_computation == FusedComputationType::kUndefined) {
|
||||
return errors::Unimplemented("Fusion is not implemented: [",
|
||||
absl::StrJoin(fused_ops, ","), "]");
|
||||
}
|
||||
|
||||
// Depending on the picked fusion type, validate fusion-specific arguments.
|
||||
if (*fused_computation == FusedComputationType::kBiasAdd ||
|
||||
*fused_computation == FusedComputationType::kBiasAddWithRelu ||
|
||||
*fused_computation == FusedComputationType::kBiasAddWithRelu6 ||
|
||||
*fused_computation == FusedComputationType::kBiasAddWithElu) {
|
||||
if (num_args != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Fused ", kernel_name,
|
||||
" with BiasAdd must have one extra argument: bias.");
|
||||
}
|
||||
}
|
||||
|
||||
if (*fused_computation == FusedComputationType::kFusedBatchNorm ||
|
||||
*fused_computation == FusedComputationType::kFusedBatchNormWithRelu ||
|
||||
*fused_computation == FusedComputationType::kFusedBatchNormWithRelu6 ||
|
||||
*fused_computation == FusedComputationType::kFusedBatchNormWithElu) {
|
||||
if (num_args != 4) {
|
||||
return errors::InvalidArgument(
|
||||
"Fused ", kernel_name,
|
||||
" with FusedBatchNorm must have four extra arguments: scale, offset, "
|
||||
"mean, variance.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->GetAttr("epsilon", &fused_computation_args->epsilon));
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
tensorflow/core/kernels/fused_eigen_output_kernels.h (new file, 331 lines)
@ -0,0 +1,331 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Output kernels for fusing computation into Eigen Tensor contractions:
|
||||
// (1) FusedConv2DOp
|
||||
// (2) FusedMatMulOp
|
||||
//
|
||||
// Supported fused computations:
|
||||
// (1) {Conv2D/MatMul} + BiasAdd + <Activation>
|
||||
// (2) {Conv2D/MatMul} + FusedBatchNorm + <Activation>
|
||||
//
|
||||
// Activation: Relu, Relu6, Elu, etc...
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
enum class FusedComputationType {
|
||||
kUndefined,
|
||||
kBiasAdd,
|
||||
kBiasAddWithRelu,
|
||||
kBiasAddWithRelu6,
|
||||
kBiasAddWithElu,
|
||||
kFusedBatchNorm,
|
||||
kFusedBatchNormWithRelu,
|
||||
kFusedBatchNormWithRelu6,
|
||||
kFusedBatchNormWithElu
|
||||
};
|
||||
|
||||
// We have to pass around additional arguments for all possible fusion types.
|
||||
struct FusedComputationArgs {
|
||||
float epsilon = 0.0; // Used by `FusedBatchNorm` fusion only
|
||||
};
|
||||
|
||||
struct FusedComputationPattern {
|
||||
FusedComputationType fused_computation;
|
||||
std::vector<string> fused_ops;
|
||||
};
|
||||
|
||||
// Parses attributes from the kernel construction context, and verifies that
// they specify a valid fused computation pattern.
|
||||
Status InitializeFusedComputation(
|
||||
OpKernelConstruction* context, const string& kernel_name,
|
||||
const std::vector<FusedComputationPattern>& patterns,
|
||||
FusedComputationType* fused_computation,
|
||||
FusedComputationArgs* fused_computation_args);
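//
// Usage sketch (illustrative; mirrors the FusedMatMulOp constructor in
// matmul_op_fused.cc): a fused kernel declares the patterns it supports and
// calls this from its constructor:
//
//   std::vector<FusedComputationPattern> patterns = {
//       {FusedComputationType::kBiasAdd, {"BiasAdd"}},
//       {FusedComputationType::kBiasAddWithRelu, {"BiasAdd", "Relu"}}};
//   OP_REQUIRES_OK(context,
//                  InitializeFusedComputation(context, "MatMul", patterns,
//                                             &fused_computation_,
//                                             &fused_computation_args_));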
|
||||
|
||||
// Type alias for the tensor contraction output mapper.
|
||||
template <typename Scalar, typename StorageIndex>
|
||||
using ContractionOutputMapper =
|
||||
Eigen::internal::blas_data_mapper<Scalar, StorageIndex, Eigen::ColMajor>;
|
||||
|
||||
// Returns input expression without any transformations.
|
||||
struct Identity {
|
||||
template <typename XprType>
|
||||
static auto apply(XprType expr) -> XprType {
|
||||
return expr;
|
||||
};
|
||||
};
|
||||
|
||||
// Applies `Relu` to the passed input expression.
|
||||
struct Relu {
|
||||
template <typename XprType>
|
||||
static auto apply(XprType expr)
|
||||
-> decltype(expr.cwiseMax(std::declval<typename XprType::Scalar>())) {
|
||||
return expr.cwiseMax(static_cast<typename XprType::Scalar>(0));
|
||||
};
|
||||
};
|
||||
|
||||
// Applies `Relu6` to the passed input expression.
|
||||
struct Relu6 {
|
||||
template <typename XprType>
|
||||
static auto apply(XprType expr)
|
||||
-> decltype(expr.cwiseMax(std::declval<typename XprType::Scalar>())
|
||||
.cwiseMin(std::declval<typename XprType::Scalar>())) {
|
||||
return expr.cwiseMax(static_cast<typename XprType::Scalar>(0))
|
||||
.cwiseMin(static_cast<typename XprType::Scalar>(6));
|
||||
};
|
||||
};
|
||||
|
||||
// Applies `Elu` to the passed input expression.
|
||||
struct Elu {
|
||||
template <typename XprType>
|
||||
static auto apply(XprType expr) -> decltype(
|
||||
(expr < std::declval<typename XprType::Scalar>())
|
||||
.select(expr.exp() -
|
||||
expr.constant(std::declval<typename XprType::Scalar>()),
|
||||
expr)) {
|
||||
return (expr < static_cast<typename XprType::Scalar>(0))
|
||||
.select(expr.exp() -
|
||||
expr.constant(static_cast<typename XprType::Scalar>(1)),
|
||||
expr);
|
||||
};
|
||||
};
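
// Illustration (a sketch, not part of the original file): the functors above
// wrap an Eigen expression; given an output block `out` and a bias expression,
// the output kernels below use them roughly as
//
//   const auto biased = out + bias;
//   out = Relu::template apply<decltype(biased)>(biased);  // max(out + bias, 0)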
|
||||
|
||||
template <typename T>
|
||||
struct BiasAddArgs {
|
||||
const T* bias_add_data = nullptr;
|
||||
|
||||
static bool IsSupported(FusedComputationType fusion) {
|
||||
return fusion == FusedComputationType::kBiasAdd ||
|
||||
fusion == FusedComputationType::kBiasAddWithRelu ||
|
||||
fusion == FusedComputationType::kBiasAddWithRelu6 ||
|
||||
fusion == FusedComputationType::kBiasAddWithElu;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FusedBatchNormArgs {
|
||||
const T* scale_data = nullptr;
|
||||
const T* offset_data = nullptr;
|
||||
const T* estimated_mean_data = nullptr;
|
||||
const T* estimated_variance_data = nullptr;
|
||||
|
||||
// Precomputed expression:
|
||||
// scaling_factor = (estimated_variance + epsilon).rsqrt() * scale
|
||||
Eigen::Tensor<T, 1, Eigen::RowMajor> scaling_factor;
|
||||
|
||||
static bool IsSupported(FusedComputationType fusion) {
|
||||
return fusion == FusedComputationType::kFusedBatchNorm ||
|
||||
fusion == FusedComputationType::kFusedBatchNormWithRelu ||
|
||||
fusion == FusedComputationType::kFusedBatchNormWithRelu6 ||
|
||||
fusion == FusedComputationType::kFusedBatchNormWithElu;
|
||||
}
|
||||
};
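
// Background (sketch): at inference time FusedBatchNorm reduces to a
// per-channel affine transform,
//
//   y = scale * (x - mean) / sqrt(variance + epsilon) + offset
//     = scaling_factor * (x - mean) + offset,
//
// where scaling_factor = scale / sqrt(variance + epsilon) is exactly the
// `scaling_factor` member precomputed by InitFusedBatchNormArgs() below.
// This is why the output kernel only needs scaling_factor, offset, and
// estimated_mean.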
|
||||
|
||||
// TensorContraction swaps lhs with rhs, and changes layout from RowMajor
|
||||
// (default in Tensorflow) to ColMajor (preferred in Eigen), and computes matmul
|
||||
// using these tensors.
|
||||
//
|
||||
// (1) Spatial Convolution (see eigen_spatial_convolutions.h):
|
||||
//
|
||||
// TensorContraction output matrix (before reshape) has a ColMajor layout, and
|
||||
// has dimensions:
|
||||
// - rows: output_channels
|
||||
// - cols: all other dimensions
|
||||
//
|
||||
// First element in every column is:
|
||||
// [batch ??, height ??, width ??, out_channel = i]
|
||||
//
|
||||
// We do not know the values of 'batch', 'height', and 'width' here (if we
// know the original dimensions, they can be computed from 'j').
|
||||
//
|
||||
// Each column of an output block is a continuous slice along the output
|
||||
// channel dimension, so we can use it to efficiently compute any
|
||||
// transformation that depends only on a channel value (e.g. add channel
|
||||
// bias).
|
||||
//
|
||||
// (2) Matrix Multiplication (see matmul_op.cc):
|
||||
//
|
||||
// For the `MxK * KxN` matrix multiplication, output matrix has a `MxN`
|
||||
// dimensions. Each column in output block is a slice of the innermost
|
||||
// dimension of the output matrix starting at offset 'i'.
|
||||
//
|
||||
// Example: In Tensorflow MatMul [8x32] * [32x64], each output block column
|
||||
// will correspond to MatMul output row of size 64 (because Tensorflow uses
|
||||
// row major storage order).
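//
// Worked example (hypothetical block shape, for illustration): in the
// [8x32] * [32x64] MatMul above, suppose an output block arrives with
// i = 0, j = 2, num_rows = 64, num_cols = 4. Column `col` of that block is
// TensorFlow output row `j + col` (rows 2..5), and the kernels below add
// bias[i .. i + num_rows - 1] (here bias[0..63]) down each column.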
|
||||
|
||||
// Output kernel that fuses BiasAdd operation into the output of tensor
|
||||
// contraction + activation function defined by Activation.
|
||||
template <typename T, typename Activation = Identity>
|
||||
struct BiasAddOutputKernel {
|
||||
explicit BiasAddOutputKernel(const BiasAddArgs<T>& args)
|
||||
: bias_data(args.bias_add_data) {}
|
||||
|
||||
template <typename StorageIndex, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
|
||||
const Eigen::TensorContractionParams& params, StorageIndex i,
|
||||
StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
|
||||
DCHECK(params.swapped_arguments);
|
||||
|
||||
const T* bias_base = bias_data + i;
|
||||
typename TTypes<T>::UnalignedConstTensor bias(bias_base, num_rows);
|
||||
|
||||
for (int col = 0; col < num_cols; ++col) {
|
||||
T* output_base = &output_mapper(0, col);
|
||||
typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
|
||||
const auto expr = output + bias;
|
||||
output = Activation::template apply<decltype(expr)>(expr);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const T* bias_data;
|
||||
};
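
// Example usage (a sketch; the real call sites are in conv_ops_fused.cc and
// matmul_op_fused.cc): the kernel is passed as the output-kernel argument of
// an Eigen contraction, e.g.
//
//   BiasAddArgs<float> args;  // filled in by InitBiasAddArgs() below
//   out.device(d) = lhs.contract(rhs, dim_pair, BiasAddOutputKernel<float>(args));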
|
||||
|
||||
// Output kernel that fuses FusedBatchNorm operation into the output of tensor
|
||||
// contraction + activation function defined by Activation.
|
||||
template <typename T, typename Activation = Identity>
|
||||
struct FusedBatchNormOutputKernel {
|
||||
FusedBatchNormOutputKernel(T epsilon, const FusedBatchNormArgs<T>& args)
|
||||
: epsilon(epsilon),
|
||||
scaling_factor_data(args.scaling_factor.data()),
|
||||
offset_data(args.offset_data),
|
||||
estimated_mean_data(args.estimated_mean_data) {}
|
||||
|
||||
template <typename StorageIndex, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
|
||||
const Eigen::TensorContractionParams& params, StorageIndex i,
|
||||
StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
|
||||
DCHECK(params.swapped_arguments);
|
||||
|
||||
const T* scaling_factor_base = scaling_factor_data + i;
|
||||
const T* offset_base = offset_data + i;
|
||||
const T* mean_base = estimated_mean_data + i;
|
||||
|
||||
typename TTypes<T>::UnalignedConstTensor scaling_factor(scaling_factor_base,
|
||||
num_rows);
|
||||
typename TTypes<T>::UnalignedConstTensor offset(offset_base, num_rows);
|
||||
typename TTypes<T>::UnalignedConstTensor mean(mean_base, num_rows);
|
||||
|
||||
for (int col = 0; col < num_cols; ++col) {
|
||||
T* output_base = &output_mapper(0, col);
|
||||
typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
|
||||
|
||||
auto scaled = (output - mean) * scaling_factor;
|
||||
auto shifted = scaled + offset;
|
||||
|
||||
output = Activation::template apply<decltype(shifted)>(shifted);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
T epsilon;
|
||||
const T* scaling_factor_data;
|
||||
const T* offset_data;
|
||||
const T* estimated_mean_data;
|
||||
};
|
||||
|
||||
// Type aliases for the output kernels, purely for the sake of better launch
|
||||
// dispatching code readability.
|
||||
template <typename T>
|
||||
using WithBiasAdd = BiasAddOutputKernel<T>;
|
||||
template <typename T>
|
||||
using WithBiasAddAndRelu = BiasAddOutputKernel<T, Relu>;
|
||||
template <typename T>
|
||||
using WithBiasAddAndRelu6 = BiasAddOutputKernel<T, Relu6>;
|
||||
template <typename T>
|
||||
using WithBiasAddAndElu = BiasAddOutputKernel<T, Elu>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNorm = FusedBatchNormOutputKernel<T>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNormAndRelu = FusedBatchNormOutputKernel<T, Relu>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNormAndRelu6 = FusedBatchNormOutputKernel<T, Relu6>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNormAndElu = FusedBatchNormOutputKernel<T, Elu>;
|
||||
|
||||
template <typename T>
|
||||
Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs<T>* args) {
|
||||
// Bias of the following dimensions: [ output_depth ]
|
||||
const Tensor& bias = context->input(2);
|
||||
|
||||
if (bias.dims() != 1)
|
||||
return errors::InvalidArgument("bias must be 1-dimensional",
|
||||
bias.shape().DebugString());
|
||||
|
||||
const auto data_ptr = [](const Tensor& tensor) -> const T* {
|
||||
return reinterpret_cast<const T*>(tensor.tensor_data().data());
|
||||
};
|
||||
|
||||
args->bias_add_data = data_ptr(bias);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon,
|
||||
FusedBatchNormArgs<T>* args) {
|
||||
const Tensor& scale = context->input(2);
|
||||
const Tensor& offset = context->input(3);
|
||||
const Tensor& estimated_mean = context->input(4);
|
||||
const Tensor& estimated_variance = context->input(5);
|
||||
|
||||
if (scale.dims() != 1)
|
||||
return errors::InvalidArgument("scale must be 1-dimensional",
|
||||
scale.shape().DebugString());
|
||||
if (offset.dims() != 1)
|
||||
return errors::InvalidArgument("offset must be 1-dimensional",
|
||||
offset.shape().DebugString());
|
||||
if (estimated_mean.dims() != 1)
|
||||
return errors::InvalidArgument("estimated_mean must be 1-dimensional",
|
||||
estimated_mean.shape().DebugString());
|
||||
if (estimated_variance.dims() != 1)
|
||||
return errors::InvalidArgument("estimated_variance must be 1-dimensional",
|
||||
estimated_variance.shape().DebugString());
|
||||
|
||||
const auto data_ptr = [](const Tensor& tensor) -> const T* {
|
||||
return reinterpret_cast<const T*>(tensor.tensor_data().data());
|
||||
};
|
||||
|
||||
args->scale_data = data_ptr(scale);
|
||||
args->offset_data = data_ptr(offset);
|
||||
args->estimated_mean_data = data_ptr(estimated_mean);
|
||||
args->estimated_variance_data = data_ptr(estimated_variance);
|
||||
|
||||
// Precompute scaling factor once for all output blocks (kernels).
|
||||
args->scaling_factor =
|
||||
(estimated_variance.flat<T>() + static_cast<T>(epsilon)).rsqrt() *
|
||||
scale.flat<T>();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_
|
tensorflow/core/kernels/matmul_op_fused.cc (new file, 180 lines)
@ -0,0 +1,180 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Implements matmul operations with other kernels baked into the
|
||||
// processing, to optimize latency and memory usage:
|
||||
// - MatMul + BiasAdd + <Activation>
|
||||
// - MatMul + FusedBatchNorm + <Activation>
|
||||
//
|
||||
// Activation: Relu, Relu6, Elu, etc...
|
||||
//
|
||||
// Currently supported only on CPU device.
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_
|
||||
|
||||
#define USE_EIGEN_TENSOR
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
|
||||
#include "tensorflow/core/kernels/fill_functor.h"
|
||||
#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename Device, typename T>
|
||||
struct LaunchFusedMatMulOp {
|
||||
void operator()(
|
||||
OpKernelContext* context, const Tensor& a, const Tensor& b,
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
||||
FusedComputationType fusion, const FusedComputationArgs& fusion_args,
|
||||
Tensor* output);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct LaunchFusedMatMulOp<CPUDevice, T> {
|
||||
void operator()(
|
||||
OpKernelContext* context, const Tensor& a, const Tensor& b,
|
||||
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
|
||||
FusedComputationType fusion, const FusedComputationArgs& fusion_args,
|
||||
Tensor* output) {
|
||||
auto lhs = a.matrix<T>();
|
||||
auto rhs = b.matrix<T>();
|
||||
auto out = output->matrix<T>();
|
||||
|
||||
auto& d = context->eigen_device<CPUDevice>();
|
||||
|
||||
BiasAddArgs<T> bias_add_args;
|
||||
if (BiasAddArgs<T>::IsSupported(fusion)) {
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
|
||||
}
|
||||
|
||||
switch (fusion) {
|
||||
case FusedComputationType::kBiasAdd:
|
||||
out.device(d) =
|
||||
lhs.contract(rhs, dim_pair, WithBiasAdd<T>(bias_add_args));
|
||||
break;
|
||||
case FusedComputationType::kUndefined:
|
||||
OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
|
||||
break;
|
||||
default:
|
||||
OP_REQUIRES_OK(context,
|
||||
errors::Internal("Fusion type is not supported"));
|
||||
}
|
||||
}
|
||||
};
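
// Note (illustrative): `dim_pair` encodes which dimensions are contracted.
// With transpose_a = false and transpose_b = false, Compute() below passes
// dim_pair[0] = {1, 0}, i.e. the columns (k) of `a` are contracted with the
// rows (k) of `b`, producing an [m, n] output.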
|
||||
|
||||
template <typename Device, typename T>
|
||||
class FusedMatMulOp : public OpKernel {
|
||||
public:
|
||||
explicit FusedMatMulOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));
|
||||
|
||||
std::vector<FusedComputationPattern> patterns;
|
||||
|
||||
using FCT = FusedComputationType;
|
||||
if (std::is_same<Device, CPUDevice>::value) {
|
||||
patterns = {{FCT::kBiasAdd, {"BiasAdd"}}};
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context, InitializeFusedComputation(
|
||||
context, "MatMul", patterns,
|
||||
&fused_computation_, &fused_computation_args_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
const Tensor& a = ctx->input(0);
|
||||
const Tensor& b = ctx->input(1);
|
||||
|
||||
// Check that the dimensions of the two matrices are valid.
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsMatrix(a.shape()),
|
||||
errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
|
||||
a.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
ctx, TensorShapeUtils::IsMatrix(b.shape()),
|
||||
errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
|
||||
b.shape().DebugString()));
|
||||
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
|
||||
dim_pair[0].first = transpose_a_ ? 0 : 1;
|
||||
dim_pair[0].second = transpose_b_ ? 1 : 0;
|
||||
|
||||
OP_REQUIRES(
|
||||
ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
|
||||
errors::InvalidArgument(
|
||||
"Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
|
||||
", In[1]: ", b.shape().DebugString()));
|
||||
int a_dim_remaining = 1 - dim_pair[0].first;
|
||||
int b_dim_remaining = 1 - dim_pair[0].second;
|
||||
TensorShape out_shape(
|
||||
{a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
|
||||
|
||||
if (out->NumElements() == 0) {
|
||||
// If a has shape [0, x] or b has shape [x, 0], the output shape
|
||||
// is a 0-element matrix, so there is nothing to do.
|
||||
return;
|
||||
}
|
||||
|
||||
if (a.NumElements() == 0 || b.NumElements() == 0) {
|
||||
// If a has shape [x, 0] and b has shape [0, y], the
|
||||
// output shape is [x, y] where x and y are non-zero, so we fill
|
||||
// the output with zeros.
|
||||
functor::SetZeroFunctor<Device, T> f;
|
||||
f(ctx->eigen_device<Device>(), out->flat<T>());
|
||||
return;
|
||||
}
|
||||
|
||||
auto launch = LaunchFusedMatMulOp<Device, T>();
|
||||
launch(ctx, a, b, dim_pair, fused_computation_, fused_computation_args_,
|
||||
out);
|
||||
}
|
||||
|
||||
private:
|
||||
bool transpose_a_;
|
||||
bool transpose_b_;
|
||||
|
||||
FusedComputationType fused_computation_ = FusedComputationType::kUndefined;
|
||||
FusedComputationArgs fused_computation_args_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(FusedMatMulOp);
|
||||
};
|
||||
|
||||
// Registration of the CPU implementations.
|
||||
#define REGISTER_FUSED_CPU_MATMUL(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
|
||||
FusedMatMulOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_float(REGISTER_FUSED_CPU_MATMUL);
|
||||
|
||||
#undef REGISTER_FUSED_CPU_MATMUL
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_
|
@ -13,13 +13,242 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
template <typename T>
|
||||
class FusedMatMulOpTest : public OpsTestBase {
|
||||
protected:
|
||||
using BiasAddGraphRunner =
|
||||
std::function<void(const Tensor& lhs_data, const Tensor& rhs_data,
|
||||
const Tensor& bias_data, Tensor* out)>;
|
||||
|
||||
// Runs a TensorFlow graph defined by the root scope, and fetches the result
// of the 'fetch' node into the output Tensor. The optional `fetch_node`
// parameter allows defining a fetch node directly using a NodeDef, for ops
// that are not supported by the C++ API.
|
||||
void RunAndFetch(const tensorflow::Scope& root, const string& fetch,
|
||||
Tensor* output, bool allow_gpu_device,
|
||||
const NodeDef* fetch_node = nullptr) {
|
||||
tensorflow::GraphDef graph;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph));
|
||||
|
||||
if (fetch_node) {
|
||||
*graph.add_node() = *fetch_node;
|
||||
}
|
||||
|
||||
// We really want to make sure that graph executed exactly as we passed it
|
||||
// to the session, so we disable various optimizations.
|
||||
tensorflow::SessionOptions session_options;
|
||||
|
||||
// Disable common runtime constant folding.
|
||||
session_options.config.mutable_graph_options()
|
||||
->mutable_optimizer_options()
|
||||
->set_opt_level(OptimizerOptions::L0);
|
||||
|
||||
// Disable Grappler optimizations for tests.
|
||||
tensorflow::RewriterConfig* cfg =
|
||||
session_options.config.mutable_graph_options()
|
||||
->mutable_rewrite_options();
|
||||
cfg->set_constant_folding(tensorflow::RewriterConfig::OFF);
|
||||
cfg->set_layout_optimizer(tensorflow::RewriterConfig::OFF);
|
||||
cfg->set_remapping(tensorflow::RewriterConfig::OFF);
|
||||
|
||||
std::unique_ptr<tensorflow::Session> session(
|
||||
tensorflow::NewSession(session_options));
|
||||
|
||||
std::vector<DeviceAttributes> available_devices;
|
||||
TF_ASSERT_OK(session->ListDevices(&available_devices))
|
||||
<< "Failed to get available session devices";
|
||||
|
||||
// Check if session has an available GPU device.
|
||||
const bool has_gpu_device =
|
||||
absl::c_any_of(available_devices, [](const DeviceAttributes& device) {
|
||||
return device.device_type() == DEVICE_GPU;
|
||||
});
|
||||
|
||||
// If the fused computation is implemented only for CPU, we don't want to
// compare GPU vs CPU numbers in this test, so place all nodes on CPU in that
// case.
|
||||
const bool place_all_on_gpu = allow_gpu_device && has_gpu_device;
|
||||
|
||||
const string device = place_all_on_gpu ? "/device:GPU:0" : "/device:CPU:0";
|
||||
for (NodeDef& mutable_node : *graph.mutable_node()) {
|
||||
mutable_node.set_device(device);
|
||||
}
|
||||
|
||||
TF_ASSERT_OK(session->Create(graph));
|
||||
|
||||
std::vector<Tensor> unfused_tensors;
|
||||
TF_ASSERT_OK(session->Run({}, {fetch}, {}, &unfused_tensors));
|
||||
|
||||
*output = unfused_tensors[0];
|
||||
}
|
||||
|
||||
void RunMatMulWithBias(const Tensor& lhs_data, const Tensor& rhs_data,
|
||||
const Tensor& bias_data, Tensor* output,
|
||||
bool allow_gpu_device = false) {
|
||||
Scope root = tensorflow::Scope::NewRootScope();
|
||||
|
||||
ops::MatMul matmul = ops::MatMul(
|
||||
root.WithOpName("matmul"),
|
||||
ops::Const(root.WithOpName("lhs"), Input::Initializer(lhs_data)),
|
||||
ops::Const(root.WithOpName("rhs"), Input::Initializer(rhs_data)));
|
||||
|
||||
ops::BiasAdd with_bias = ops::BiasAdd(
|
||||
root.WithOpName("with_bias"), matmul,
|
||||
ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
|
||||
|
||||
RunAndFetch(root, "with_bias", output, allow_gpu_device);
|
||||
}
|
||||
|
||||
void RunFusedMatMulOp(const Tensor& lhs_data, const Tensor& rhs_data,
|
||||
const std::vector<Tensor>& args_data,
|
||||
const std::vector<string>& fused_ops, Tensor* output,
|
||||
bool allow_gpu_device = false) {
|
||||
Scope root = tensorflow::Scope::NewRootScope();
|
||||
|
||||
DataType dtype = DataTypeToEnum<T>::v();
|
||||
int num_args = static_cast<int>(args_data.size());
|
||||
|
||||
Output lhs =
|
||||
ops::Const(root.WithOpName("lhs"), Input::Initializer(lhs_data));
|
||||
Output rhs =
|
||||
ops::Const(root.WithOpName("rhs"), Input::Initializer(rhs_data));
|
||||
|
||||
std::vector<NodeDefBuilder::NodeOut> args;
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
Output arg = ops::Const(root.WithOpName(absl::StrCat("arg", i)),
|
||||
Input::Initializer(args_data[i]));
|
||||
args.emplace_back(arg.name(), 0, dtype);
|
||||
}
|
||||
|
||||
NodeDef fused_matmul;
|
||||
TF_EXPECT_OK(NodeDefBuilder("fused_matmul", "_FusedMatMul")
|
||||
.Input({lhs.name(), 0, dtype})
|
||||
.Input({rhs.name(), 0, dtype})
|
||||
.Input(args)
|
||||
.Attr("num_args", num_args)
|
||||
.Attr("T", dtype)
|
||||
.Attr("fused_ops", fused_ops)
|
||||
.Finalize(&fused_matmul));
|
||||
|
||||
RunAndFetch(root, fused_matmul.name(), output, allow_gpu_device,
|
||||
&fused_matmul);
|
||||
}
|
||||
|
||||
void VerifyBiasAddTensorsNear(int m, int k, int n,
|
||||
const BiasAddGraphRunner& run_default,
|
||||
const BiasAddGraphRunner& run_fused) {
|
||||
DataType dtype = DataTypeToEnum<T>::v();
|
||||
|
||||
Tensor lhs(dtype, {m, k});
|
||||
lhs.flat<T>() = lhs.flat<T>().setRandom();
|
||||
|
||||
// Add some negative values to the rhs to properly test Relu.
|
||||
Tensor rhs(dtype, {k, n});
|
||||
rhs.flat<T>() = rhs.flat<T>().setRandom();
|
||||
rhs.flat<T>() -= rhs.flat<T>().constant(static_cast<T>(0.5f));
|
||||
|
||||
// Bias added to the inner dimension.
|
||||
const int bias_size = n;
|
||||
Tensor bias(dtype, {bias_size});
|
||||
bias.flat<T>() = bias.flat<T>().setRandom();
|
||||
bias.flat<T>() += bias.flat<T>().constant(static_cast<T>(0.5f));
|
||||
|
||||
Tensor matmul;
|
||||
Tensor fused_matmul;
|
||||
|
||||
run_default(lhs, rhs, bias, &matmul);
|
||||
run_fused(lhs, rhs, bias, &fused_matmul);
|
||||
|
||||
ASSERT_EQ(matmul.dtype(), fused_matmul.dtype());
|
||||
ASSERT_EQ(matmul.shape(), fused_matmul.shape());
|
||||
|
||||
test::ExpectClose(matmul, fused_matmul, /*atol=*/1e-5);
|
||||
}
|
||||
|
||||
// Verifies that computing MatMul+BiasAdd in a graph is identical to
|
||||
// FusedMatMul.
|
||||
void VerifyMatMulWithBias(int m, int k, int n) {
|
||||
const BiasAddGraphRunner run_default =
|
||||
[this](const Tensor& input_data, const Tensor& filter_data,
|
||||
const Tensor& bias_data, Tensor* out) {
|
||||
RunMatMulWithBias(input_data, filter_data, bias_data, out);
|
||||
};
|
||||
|
||||
const BiasAddGraphRunner run_fused = [this](const Tensor& input_data,
|
||||
const Tensor& filter_data,
|
||||
const Tensor& bias_data,
|
||||
Tensor* out) {
|
||||
RunFusedMatMulOp(input_data, filter_data, {bias_data}, {"BiasAdd"}, out);
|
||||
};
|
||||
|
||||
VerifyBiasAddTensorsNear(m, k, n, run_default, run_fused);
|
||||
}
|
||||
};
|
||||
|
||||
// MatMul with BatchNorm can be tested only with `T=float`, because the default
// `FusedBatchNorm` kernel supports only floats for scale, mean, and variance.
|
||||
|
||||
template <typename T>
|
||||
class FusedMatMulWithBiasOpTest : public FusedMatMulOpTest<T> {};
|
||||
template <typename T>
|
||||
class FusedMatMulWithBatchNormOpTest : public FusedMatMulOpTest<T> {};
|
||||
|
||||
TYPED_TEST_SUITE_P(FusedMatMulWithBiasOpTest);
|
||||
TYPED_TEST_SUITE_P(FusedMatMulWithBatchNormOpTest);
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
// MatMul + BiasAdd + {Activation} //
|
||||
// -------------------------------------------------------------------------- //
|
||||
|
||||
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x256) {
|
||||
this->VerifyMatMulWithBias(256, 256, 256);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x256) {
|
||||
this->VerifyMatMulWithBias(1, 256, 256);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul256x256x1) {
|
||||
this->VerifyMatMulWithBias(256, 256, 1);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(FusedMatMulWithBiasOpTest, MatMul1x256x1) {
|
||||
this->VerifyMatMulWithBias(1, 256, 1);
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
// MatMul + FusedBatchNorm + {Activation} //
|
||||
// -------------------------------------------------------------------------- //
|
||||
|
||||
// TODO(ezhulenev): Add tests for FusedBatchNorm.
|
||||
|
||||
REGISTER_TYPED_TEST_SUITE_P(FusedMatMulWithBiasOpTest, //
|
||||
MatMul256x256x256, //
|
||||
MatMul1x256x256, //
|
||||
MatMul256x256x1, //
|
||||
MatMul1x256x1);
|
||||
|
||||
// TODO(ezhulenev): Add support for more data types.
|
||||
using FusedBiasAddDataTypes = ::testing::Types<float>;
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(Test, FusedMatMulWithBiasOpTest,
|
||||
FusedBiasAddDataTypes);
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// Performance benchmarks are below. //
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
template <typename T>
|
||||
static Graph* Matmul(int m, int k, int n, bool transpose_a, bool transpose_b,
|
||||
DataType type) {
|
||||
|
@ -856,6 +856,25 @@ REGISTER_OP("SparseMatMul")
|
||||
.Attr("Tb: {float, bfloat16} = DT_FLOAT")
|
||||
.SetShapeFn(shape_inference::MatMulShape);
|
||||
|
||||
REGISTER_OP("_FusedMatMul")
|
||||
.Input("a: T")
|
||||
.Input("b: T")
|
||||
.Input("args: num_args * T")
|
||||
.Output("product: T")
|
||||
.Attr("transpose_a: bool = false")
|
||||
.Attr("transpose_b: bool = false")
|
||||
.Attr("T: {float}")
|
||||
.Attr("num_args: int >= 0")
|
||||
.Attr("fused_ops: list(string) = []")
|
||||
// Attributes for the FusedBatchNorm ----------- //
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
// --------------------------------------------- //
|
||||
.SetShapeFn(shape_inference::MatMulShape)
|
||||
.Doc(R"doc(
|
||||
*NOTE*: Do not invoke this operator directly in Python. Grappler is
|
||||
expected to create these operators.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// For operations where the output is a reduction function along some
|
||||