From 6ab77f3be330714746054eb678c9c4116f300692 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Fri, 17 Jan 2020 00:13:26 +0000 Subject: [PATCH] Add complex number support for tf.extract_image_patches This PR tries to address the issue raised in 35955 where there was no complex number support for tf.extract_image_patches. The op `tf.extract_image_patches` itself could be used in many ways than just image so it makes sense to add complex support. This fix fixes 35955. Signed-off-by: Yong Tang Register GPU types for tf.extract_image_patches Signed-off-by: Yong Tang Fix build failure with GPU Signed-off-by: Yong Tang Update array_ops.cc to only include TF_CALL_NUMBER_TYPES This fix updates array_ops.cc to only include TF_CALL_NUMBER_TYPES for extract_image_patches (realnumbertypes + complex64 + complex128). Signed-off-by: Yong Tang --- tensorflow/core/kernels/extract_image_patches_op.cc | 6 +++--- tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc | 2 +- tensorflow/core/ops/array_ops.cc | 4 +++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tensorflow/core/kernels/extract_image_patches_op.cc b/tensorflow/core/kernels/extract_image_patches_op.cc index 0fc1f567a92..2cc9933965e 100644 --- a/tensorflow/core/kernels/extract_image_patches_op.cc +++ b/tensorflow/core/kernels/extract_image_patches_op.cc @@ -126,7 +126,7 @@ class ExtractImagePatchesOp : public UnaryOp { Name("ExtractImagePatches").Device(DEVICE_CPU).TypeConstraint("T"), \ ExtractImagePatchesOp); -TF_CALL_REAL_NUMBER_TYPES(REGISTER); +TF_CALL_NUMBER_TYPES(REGISTER); #undef REGISTER @@ -145,7 +145,7 @@ namespace functor { typename TTypes::Tensor output); \ extern template struct ExtractImagePatchesForward; -TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC); #undef DECLARE_GPU_SPEC @@ -157,7 +157,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); Name("ExtractImagePatches").Device(DEVICE_GPU).TypeConstraint("T"), \ ExtractImagePatchesOp); -TF_CALL_GPU_NUMBER_TYPES(REGISTER); +TF_CALL_GPU_ALL_TYPES(REGISTER); #undef REGISTER diff --git a/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc index 650c51fc765..e6a49da7fd2 100644 --- a/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc +++ b/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc @@ -29,7 +29,7 @@ namespace functor { #define REGISTER(T) template struct ExtractImagePatchesForward; -TF_CALL_GPU_NUMBER_TYPES(REGISTER); +TF_CALL_GPU_ALL_TYPES(REGISTER); #undef REGISTER diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 60efdcb7a73..65c9510a1f2 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2525,7 +2525,9 @@ REGISTER_OP("ExtractImagePatches") .Attr("ksizes: list(int) >= 4") .Attr("strides: list(int) >= 4") .Attr("rates: list(int) >= 4") - .Attr("T: realnumbertype") + .Attr( + "T: {bfloat16, half, float, double, int8, int16, int32, int64, " + "uint8, uint16, uint32, uint64, complex64, complex128, bool}") .Attr(GetPaddingAttrString()) .SetShapeFn([](InferenceContext* c) { ShapeHandle input_shape;