Fix tf_cnn_benchmarks runtime failures
PiperOrigin-RevId: 247771407
This commit is contained in:
parent
4330405781
commit
2db22ecd6c
@ -697,7 +697,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
|||||||
// adds necessary annotations to the graph.
|
// adds necessary annotations to the graph.
|
||||||
// TODO(ezhulenev): Convert in other direction for fp16?
|
// TODO(ezhulenev): Convert in other direction for fp16?
|
||||||
const TensorFormat compute_data_format =
|
const TensorFormat compute_data_format =
|
||||||
compute_in_nhwc && data_format == FORMAT_NHWC ? FORMAT_NHWC : FORMAT_NCHW;
|
(compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC
|
||||||
|
: FORMAT_NCHW;
|
||||||
|
|
||||||
VLOG(3) << "Compute Conv2D with cuDNN:"
|
VLOG(3) << "Compute Conv2D with cuDNN:"
|
||||||
<< " data_format=" << ToString(data_format)
|
<< " data_format=" << ToString(data_format)
|
||||||
@ -853,7 +854,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
|||||||
|
|
||||||
Tensor transformed_filter;
|
Tensor transformed_filter;
|
||||||
|
|
||||||
const auto transform_filter = [&](FilterTensorFormat dst_format) -> void {
|
const auto transform_filter = [&](FilterTensorFormat dst_format) -> Status {
|
||||||
VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
|
VLOG(4) << "Transform filter tensor from " << ToString(FORMAT_HWIO)
|
||||||
<< " to " << ToString(dst_format);
|
<< " to " << ToString(dst_format);
|
||||||
|
|
||||||
@ -864,18 +865,20 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
|||||||
: TensorShape({filter.dim_size(3), filter.dim_size(0),
|
: TensorShape({filter.dim_size(3), filter.dim_size(0),
|
||||||
filter.dim_size(1), filter.dim_size(2)});
|
filter.dim_size(1), filter.dim_size(2)});
|
||||||
|
|
||||||
OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
|
TF_RETURN_IF_ERROR(ctx->allocate_temp(DataTypeToEnum<T>::value, dst_shape,
|
||||||
&transformed_filter));
|
&transformed_filter));
|
||||||
functor::TransformFilter<GPUDevice, T, int, 4>()(
|
functor::TransformFilter<GPUDevice, T, int, 4>()(
|
||||||
ctx->eigen_device<GPUDevice>(), dst_format,
|
ctx->eigen_device<GPUDevice>(), dst_format,
|
||||||
To32Bit(filter.tensor<T, 4>()),
|
To32Bit(filter.tensor<T, 4>()),
|
||||||
To32Bit(transformed_filter.tensor<T, 4>()));
|
To32Bit(transformed_filter.tensor<T, 4>()));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
};
|
};
|
||||||
|
|
||||||
if (compute_data_format == FORMAT_NCHW) {
|
if (compute_data_format == FORMAT_NCHW) {
|
||||||
transform_filter(FORMAT_OIHW);
|
OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OIHW));
|
||||||
} else if (compute_data_format == FORMAT_NHWC) {
|
} else if (compute_data_format == FORMAT_NHWC) {
|
||||||
transform_filter(FORMAT_OHWI);
|
OP_REQUIRES_OK(ctx, transform_filter(FORMAT_OHWI));
|
||||||
} else {
|
} else {
|
||||||
ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
|
ctx->SetStatus(errors::InvalidArgument("Invalid compute data format: ",
|
||||||
ToString(compute_data_format)));
|
ToString(compute_data_format)));
|
||||||
|
Loading…
Reference in New Issue
Block a user