Change conv_input_scale and side_input_scale from attributes to inputs for improved flexibility, in fused_conv2d_bias_activation op.
PiperOrigin-RevId: 168311988
This commit is contained in:
parent
4b4e10f9c8
commit
60f15462be
@ -41,6 +41,7 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename T>
|
||||
@ -66,6 +67,7 @@ template <>
|
||||
struct Int8x4ToInt32<int8> {
|
||||
using type = int32;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// T is the element type of the conv_input, filter and side_input tensors.
|
||||
// BiasType is the element type of the bias tensor, which can be different.
|
||||
@ -73,9 +75,20 @@ struct Int8x4ToInt32<int8> {
|
||||
template <typename Device, typename T, typename BiasType, typename ScaleType>
|
||||
class FusedConv2DBiasActivationOp : public OpKernel {
|
||||
public:
|
||||
enum InputIndexes {
|
||||
kConvInput = 0,
|
||||
kFilter,
|
||||
kBias,
|
||||
kSideInput,
|
||||
kConvInputScale,
|
||||
kSideInputScale,
|
||||
kNumInputs
|
||||
};
|
||||
|
||||
explicit FusedConv2DBiasActivationOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
string data_format_str, filter_format_str;
|
||||
CHECK_EQ(kNumInputs, context->num_inputs());
|
||||
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
|
||||
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
@ -125,13 +138,6 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
||||
errors::InvalidArgument("Current implementation only supports "
|
||||
"RELU as the activation function."));
|
||||
cudnn_use_autotune_ = CudnnUseAutotune();
|
||||
float conv_input_scale_flt, side_input_scale_flt;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("conv_input_scale", &conv_input_scale_flt));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("side_input_scale", &side_input_scale_flt));
|
||||
conv_input_scale_ = conv_input_scale_flt;
|
||||
side_input_scale_ = side_input_scale_flt;
|
||||
}
|
||||
|
||||
Status CheckShape(const Tensor& tensor, const string& tensor_name) {
|
||||
@ -154,22 +160,30 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// The conv_input tensor is one of the following formats:
|
||||
// NHWC, NCHW, NCHW_VECT_C.
|
||||
const Tensor& conv_input = context->input(0);
|
||||
const Tensor& conv_input = context->input(kConvInput);
|
||||
OP_REQUIRES_OK(context, CheckShape(conv_input, "conv_input"));
|
||||
|
||||
// The filter tensor is one of the following formats:
|
||||
// HWIO, OIHW, OIHW_VECT_I.
|
||||
const Tensor& filter = context->input(1);
|
||||
const Tensor& filter = context->input(kFilter);
|
||||
OP_REQUIRES_OK(context, CheckShape(filter, "filter"));
|
||||
|
||||
// Input bias is a 1-D tensor, with size matching output depth.
|
||||
const Tensor& bias = context->input(2);
|
||||
const Tensor& bias = context->input(kBias);
|
||||
OP_REQUIRES_OK(context, CheckShape(bias, "conv_input"));
|
||||
|
||||
const Tensor& conv_input_scale_tensor = context->input(kConvInputScale);
|
||||
const Tensor& side_input_scale_tensor = context->input(kSideInputScale);
|
||||
|
||||
auto conv_input_scale = *reinterpret_cast<const ScaleType*>(
|
||||
conv_input_scale_tensor.tensor_data().data());
|
||||
auto side_input_scale = *reinterpret_cast<const ScaleType*>(
|
||||
side_input_scale_tensor.tensor_data().data());
|
||||
|
||||
// If side_input_scale != 0, then side_input is not ignored and
|
||||
// has the same type and dimensions as the output.
|
||||
const Tensor& side_input = context->input(3);
|
||||
if (side_input_scale_ != 0) {
|
||||
const Tensor& side_input = context->input(kSideInput);
|
||||
if (side_input_scale != 0) {
|
||||
OP_REQUIRES_OK(context, CheckShape(side_input, "side_input"));
|
||||
}
|
||||
|
||||
@ -212,10 +226,10 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
launcher_.launch(context, cudnn_use_autotune_, conv_input,
|
||||
conv_input_scale_, filter, stride_rows_, stride_cols_,
|
||||
eigen_padding_type_, side_input, side_input_scale_, bias,
|
||||
activation_mode_, data_format_, filter_format_, output);
|
||||
launcher_.launch(context, cudnn_use_autotune_, conv_input, conv_input_scale,
|
||||
filter, stride_rows_, stride_cols_, eigen_padding_type_,
|
||||
side_input, side_input_scale, bias, activation_mode_,
|
||||
data_format_, filter_format_, output);
|
||||
}
|
||||
|
||||
private:
|
||||
@ -225,8 +239,6 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
||||
ActivationMode activation_mode_;
|
||||
TensorFormat data_format_;
|
||||
FilterTensorFormat filter_format_;
|
||||
ScaleType conv_input_scale_;
|
||||
ScaleType side_input_scale_;
|
||||
LaunchFusedConv2DBiasActivationOp<Device, T, BiasType, ScaleType> launcher_;
|
||||
bool cudnn_use_autotune_;
|
||||
|
||||
@ -579,14 +591,18 @@ REGISTER_KERNEL_BUILDER(
|
||||
Name("FusedConv2DBiasActivation")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.TypeConstraint<float>("Tbias"),
|
||||
.TypeConstraint<float>("Tbias")
|
||||
.HostMemory("conv_input_scale")
|
||||
.HostMemory("side_input_scale"),
|
||||
FusedConv2DBiasActivationOp<GPUDevice, float, float, float>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("FusedConv2DBiasActivation")
|
||||
.Device(DEVICE_GPU)
|
||||
.TypeConstraint<qint8>("T")
|
||||
.TypeConstraint<float>("Tbias"),
|
||||
.TypeConstraint<float>("Tbias")
|
||||
.HostMemory("conv_input_scale")
|
||||
.HostMemory("side_input_scale"),
|
||||
FusedConv2DBiasActivationOp<GPUDevice, qint8, float, float>);
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
@ -42,11 +42,11 @@ REGISTER_OP("FusedConv2DBiasActivation")
|
||||
.Input("filter: T")
|
||||
.Input("bias: Tbias")
|
||||
.Input("side_input: T")
|
||||
.Input("conv_input_scale: float")
|
||||
.Input("side_input_scale: float")
|
||||
.Output("output: T")
|
||||
.Attr("T: {float, half, qint8}")
|
||||
.Attr("Tbias: {float, half}")
|
||||
.Attr("conv_input_scale: float = 1.0")
|
||||
.Attr("side_input_scale: float = 0.0")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
||||
@ -97,6 +97,11 @@ REGISTER_OP("FusedConv2DBiasActivation")
|
||||
TF_RETURN_IF_ERROR(c->Merge(side_input_shape, c->output(0), &unused));
|
||||
}
|
||||
|
||||
// Check that conv_input_scale and side_input_scale are scalar tensors.
|
||||
ShapeHandle unused;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
|
||||
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
@ -117,15 +122,15 @@ REGISTER_OP("FusedConv2DBiasActivation")
|
||||
side_input: A tensor with format as specified by `data_format` (see below).
|
||||
This tensor will be ignored and can be [] if side_input_scale == 0.
|
||||
Otherwise, the size of each dimension must match the `output` tensor.
|
||||
conv_input_scale: scalar float value to be multiplied by `conv_input`.
|
||||
(conceptually.. in reality it is applied after convolution).
|
||||
side_input_scale: scalar float value to be multiplied by `side_input`.
|
||||
output: A tensor with format as specified by `data_format` (see below).
|
||||
The dimension sizes are determined automatically based on other inputs
|
||||
and attributes.
|
||||
T: The element data type of `conv_input`, `side_input` and `output` tensors.
|
||||
Note: must match with the `data_format`.
|
||||
Tbias: The element data type of `bias`.
|
||||
conv_input_scale: scalar float value to be multiplied by `conv_input`.
|
||||
(conceptually.. in reality it is applied after convolution).
|
||||
side_input_scale: scalar float value to be multiplied by `side_input`.
|
||||
strides: 1-D tensor of length 4. The stride of the sliding window for each
|
||||
dimension of `input`. The dimension order is determined by the value of
|
||||
`data_format`, see below for details.
|
||||
|
@ -97,11 +97,11 @@ def fused_conv2d_bias_activation(conv_input,
|
||||
conv_input,
|
||||
filter,
|
||||
bias,
|
||||
side_input,
|
||||
conv_input_scale,
|
||||
side_input_scale,
|
||||
padding=padding,
|
||||
strides=strides,
|
||||
conv_input_scale=conv_input_scale,
|
||||
side_input_scale=side_input_scale,
|
||||
side_input=side_input,
|
||||
activation_mode=activation_mode,
|
||||
data_format=data_format,
|
||||
filter_format=filter_format,
|
||||
|
Loading…
Reference in New Issue
Block a user