In the fused_conv2d_bias_activation op, change conv_input_scale and side_input_scale from attributes to inputs for improved flexibility.
PiperOrigin-RevId: 168311988
This commit is contained in:
parent
4b4e10f9c8
commit
60f15462be
@ -41,6 +41,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -66,6 +67,7 @@ template <>
|
|||||||
struct Int8x4ToInt32<int8> {
|
struct Int8x4ToInt32<int8> {
|
||||||
using type = int32;
|
using type = int32;
|
||||||
};
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// T is the element type of the conv_input, filter and side_input tensors.
|
// T is the element type of the conv_input, filter and side_input tensors.
|
||||||
// BiasType is the element type of the bias tensor, which can be different.
|
// BiasType is the element type of the bias tensor, which can be different.
|
||||||
@ -73,9 +75,20 @@ struct Int8x4ToInt32<int8> {
|
|||||||
template <typename Device, typename T, typename BiasType, typename ScaleType>
|
template <typename Device, typename T, typename BiasType, typename ScaleType>
|
||||||
class FusedConv2DBiasActivationOp : public OpKernel {
|
class FusedConv2DBiasActivationOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
|
enum InputIndexes {
|
||||||
|
kConvInput = 0,
|
||||||
|
kFilter,
|
||||||
|
kBias,
|
||||||
|
kSideInput,
|
||||||
|
kConvInputScale,
|
||||||
|
kSideInputScale,
|
||||||
|
kNumInputs
|
||||||
|
};
|
||||||
|
|
||||||
explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
|
explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
|
||||||
: OpKernel(context) {
|
: OpKernel(context) {
|
||||||
string data_format_str, filter_format_str;
|
string data_format_str, filter_format_str;
|
||||||
|
CHECK_EQ(kNumInputs, context->num_inputs());
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
|
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
|
||||||
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
|
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
|
||||||
errors::InvalidArgument("Invalid data format"));
|
errors::InvalidArgument("Invalid data format"));
|
||||||
@ -125,13 +138,6 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
|||||||
errors::InvalidArgument("Current implementation only supports "
|
errors::InvalidArgument("Current implementation only supports "
|
||||||
"RELU as the activation function."));
|
"RELU as the activation function."));
|
||||||
cudnn_use_autotune_ = CudnnUseAutotune();
|
cudnn_use_autotune_ = CudnnUseAutotune();
|
||||||
float conv_input_scale_flt, side_input_scale_flt;
|
|
||||||
OP_REQUIRES_OK(context,
|
|
||||||
context->GetAttr("conv_input_scale", &conv_input_scale_flt));
|
|
||||||
OP_REQUIRES_OK(context,
|
|
||||||
context->GetAttr("side_input_scale", &side_input_scale_flt));
|
|
||||||
conv_input_scale_ = conv_input_scale_flt;
|
|
||||||
side_input_scale_ = side_input_scale_flt;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CheckShape(const Tensor& tensor, const string& tensor_name) {
|
Status CheckShape(const Tensor& tensor, const string& tensor_name) {
|
||||||
@ -154,22 +160,30 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
|||||||
void Compute(OpKernelContext* context) override {
|
void Compute(OpKernelContext* context) override {
|
||||||
// The conv_input tensor is one of the following formats:
|
// The conv_input tensor is one of the following formats:
|
||||||
// NHWC, NCHW, NCHW_VECT_C.
|
// NHWC, NCHW, NCHW_VECT_C.
|
||||||
const Tensor& conv_input = context->input(0);
|
const Tensor& conv_input = context->input(kConvInput);
|
||||||
OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));
|
OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));
|
||||||
|
|
||||||
// The filter tensor is one of the following formats:
|
// The filter tensor is one of the following formats:
|
||||||
// HWIO, OIHW, OIHW_VECT_I.
|
// HWIO, OIHW, OIHW_VECT_I.
|
||||||
const Tensor& filter = context->input(1);
|
const Tensor& filter = context->input(kFilter);
|
||||||
OP_REQUIRES_OK(context, CheckShape(filter, "filter"));
|
OP_REQUIRES_OK(context, CheckShape(filter, "filter"));
|
||||||
|
|
||||||
// Input bias is a 1-D tensor, with size matching output depth.
|
// Input bias is a 1-D tensor, with size matching output depth.
|
||||||
const Tensor& bias = context->input(2);
|
const Tensor& bias = context->input(kBias);
|
||||||
OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
|
OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
|
||||||
|
|
||||||
|
const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
|
||||||
|
const Tensor& side_input_scale_tensor = context->input(kSideInputScale);
|
||||||
|
|
||||||
|
auto conv_input_scale = *reinterpret_cast<const ScaleType*>(
|
||||||
|
conv_input_scale_tensor.tensor_data().data());
|
||||||
|
auto side_input_scale = *reinterpret_cast<const ScaleType*>(
|
||||||
|
side_input_scale_tensor.tensor_data().data());
|
||||||
|
|
||||||
// If side_input_scale != 0, then side_input is not ignored and
|
// If side_input_scale != 0, then side_input is not ignored and
|
||||||
// has the same type and dimensions as the output.
|
// has the same type and dimensions as the output.
|
||||||
const Tensor& side_input = context->input(3);
|
const Tensor& side_input = context->input(kSideInput);
|
||||||
if (side_input_scale_ != 0) {
|
if (side_input_scale != 0) {
|
||||||
OP_REQUIRES_OK(context, CheckShape(side_input, "side_input"));
|
OP_REQUIRES_OK(context, CheckShape(side_input, "side_input"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -212,10 +226,10 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
launcher_.launch(context, cudnn_use_autotune_, conv_input,
|
launcher_.launch(context, cudnn_use_autotune_, conv_input, conv_input_scale,
|
||||||
conv_input_scale_, filter, stride_rows_, stride_cols_,
|
filter, stride_rows_, stride_cols_, eigen_padding_type_,
|
||||||
eigen_padding_type_, side_input, side_input_scale_, bias,
|
side_input, side_input_scale, bias, activation_mode_,
|
||||||
activation_mode_, data_format_, filter_format_, output);
|
data_format_, filter_format_, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
@ -225,8 +239,6 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
|||||||
ActivationMode activation_mode_;
|
ActivationMode activation_mode_;
|
||||||
TensorFormat data_format_;
|
TensorFormat data_format_;
|
||||||
FilterTensorFormat filter_format_;
|
FilterTensorFormat filter_format_;
|
||||||
ScaleType conv_input_scale_;
|
|
||||||
ScaleType side_input_scale_;
|
|
||||||
LaunchFusedConv2DBiasActivationOp<Device, T, BiasType, ScaleType> launcher_;
|
LaunchFusedConv2DBiasActivationOp<Device, T, BiasType, ScaleType> launcher_;
|
||||||
bool cudnn_use_autotune_;
|
bool cudnn_use_autotune_;
|
||||||
|
|
||||||
@ -579,14 +591,18 @@ REGISTER_KERNEL_BUILDER(
|
|||||||
Name("FusedConv2DBiasActivation")
|
Name("FusedConv2DBiasActivation")
|
||||||
.Device(DEVICE_GPU)
|
.Device(DEVICE_GPU)
|
||||||
.TypeConstraint<float>("T")
|
.TypeConstraint<float>("T")
|
||||||
.TypeConstraint<float>("Tbias"),
|
.TypeConstraint<float>("Tbias")
|
||||||
|
.HostMemory("conv_input_scale")
|
||||||
|
.HostMemory("side_input_scale"),
|
||||||
FusedConv2DBiasActivationOp<GPUDevice, float, float, float>);
|
FusedConv2DBiasActivationOp<GPUDevice, float, float, float>);
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(
|
||||||
Name("FusedConv2DBiasActivation")
|
Name("FusedConv2DBiasActivation")
|
||||||
.Device(DEVICE_GPU)
|
.Device(DEVICE_GPU)
|
||||||
.TypeConstraint<qint8>("T")
|
.TypeConstraint<qint8>("T")
|
||||||
.TypeConstraint<float>("Tbias"),
|
.TypeConstraint<float>("Tbias")
|
||||||
|
.HostMemory("conv_input_scale")
|
||||||
|
.HostMemory("side_input_scale"),
|
||||||
FusedConv2DBiasActivationOp<GPUDevice, qint8, float, float>);
|
FusedConv2DBiasActivationOp<GPUDevice, qint8, float, float>);
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
@ -42,11 +42,11 @@ REGISTER_OP("FusedConv2DBiasActivation")
|
|||||||
.Input("filter: T")
|
.Input("filter: T")
|
||||||
.Input("bias: Tbias")
|
.Input("bias: Tbias")
|
||||||
.Input("side_input: T")
|
.Input("side_input: T")
|
||||||
|
.Input("conv_input_scale: float")
|
||||||
|
.Input("side_input_scale: float")
|
||||||
.Output("output: T")
|
.Output("output: T")
|
||||||
.Attr("T: {float, half, qint8}")
|
.Attr("T: {float, half, qint8}")
|
||||||
.Attr("Tbias: {float, half}")
|
.Attr("Tbias: {float, half}")
|
||||||
.Attr("conv_input_scale: float = 1.0")
|
|
||||||
.Attr("side_input_scale: float = 0.0")
|
|
||||||
.Attr("strides: list(int)")
|
.Attr("strides: list(int)")
|
||||||
.Attr(GetPaddingAttrString())
|
.Attr(GetPaddingAttrString())
|
||||||
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
||||||
@ -97,6 +97,11 @@ REGISTER_OP("FusedConv2DBiasActivation")
|
|||||||
TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused));
|
TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check that conv_input_scale and side_input_scale are scalar tensors.
|
||||||
|
ShapeHandle unused;
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
|
||||||
|
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
})
|
})
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
@ -117,15 +122,15 @@ REGISTER_OP("FusedConv2DBiasActivation")
|
|||||||
side_input: A tensor with format as specified by `data_format` (see below).
|
side_input: A tensor with format as specified by `data_format` (see below).
|
||||||
This tensor will be ignored and can be [] if side_input_scale == 0.
|
This tensor will be ignored and can be [] if side_input_scale == 0.
|
||||||
Otherwise, the size of each dimension must match the `output` tensor.
|
Otherwise, the size of each dimension must match the `output` tensor.
|
||||||
|
conv_input_scale: scalar float value to be multiplied by `conv_input`.
|
||||||
|
(conceptually; in reality it is applied after the convolution).
|
||||||
|
side_input_scale: scalar float value to be multiplied by `side_input`.
|
||||||
output: A tensor with format as specified by `data_format` (see below).
|
output: A tensor with format as specified by `data_format` (see below).
|
||||||
The dimension sizes are determined automatically based on other inputs
|
The dimension sizes are determined automatically based on other inputs
|
||||||
and attributes.
|
and attributes.
|
||||||
T: The element data type of `conv_input`, `side_input` and `output` tensors.
|
T: The element data type of `conv_input`, `side_input` and `output` tensors.
|
||||||
Note: must match with the `data_format`.
|
Note: must match with the `data_format`.
|
||||||
Tbias: The element data type of `bias`.
|
Tbias: The element data type of `bias`.
|
||||||
conv_input_scale: scalar float value to be multiplied by `conv_input`.
|
|
||||||
(conceptually.. in reality it is applied after convolution).
|
|
||||||
side_input_scale: scalar float value to be multiplied by `side_input`.
|
|
||||||
strides: 1-D tensor of length 4. The stride of the sliding window for each
|
strides: 1-D tensor of length 4. The stride of the sliding window for each
|
||||||
dimension of `input`. The dimension order is determined by the value of
|
dimension of `input`. The dimension order is determined by the value of
|
||||||
`data_format`, see below for details.
|
`data_format`, see below for details.
|
||||||
|
@ -97,11 +97,11 @@ def fused_conv2d_bias_activation(conv_input,
|
|||||||
conv_input,
|
conv_input,
|
||||||
filter,
|
filter,
|
||||||
bias,
|
bias,
|
||||||
|
side_input,
|
||||||
|
conv_input_scale,
|
||||||
|
side_input_scale,
|
||||||
padding=padding,
|
padding=padding,
|
||||||
strides=strides,
|
strides=strides,
|
||||||
conv_input_scale=conv_input_scale,
|
|
||||||
side_input_scale=side_input_scale,
|
|
||||||
side_input=side_input,
|
|
||||||
activation_mode=activation_mode,
|
activation_mode=activation_mode,
|
||||||
data_format=data_format,
|
data_format=data_format,
|
||||||
filter_format=filter_format,
|
filter_format=filter_format,
|
||||||
|
Loading…
Reference in New Issue
Block a user