diff --git a/tensorflow/core/kernels/extract_image_patches_op.cc b/tensorflow/core/kernels/extract_image_patches_op.cc index 0fc1f567a92..2cc9933965e 100644 --- a/tensorflow/core/kernels/extract_image_patches_op.cc +++ b/tensorflow/core/kernels/extract_image_patches_op.cc @@ -126,7 +126,7 @@ class ExtractImagePatchesOp : public UnaryOp { Name("ExtractImagePatches").Device(DEVICE_CPU).TypeConstraint("T"), \ ExtractImagePatchesOp); -TF_CALL_REAL_NUMBER_TYPES(REGISTER); +TF_CALL_NUMBER_TYPES(REGISTER); #undef REGISTER @@ -145,7 +145,7 @@ namespace functor { typename TTypes::Tensor output); \ extern template struct ExtractImagePatchesForward; -TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_SPEC); #undef DECLARE_GPU_SPEC @@ -157,7 +157,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); Name("ExtractImagePatches").Device(DEVICE_GPU).TypeConstraint("T"), \ ExtractImagePatchesOp); -TF_CALL_GPU_NUMBER_TYPES(REGISTER); +TF_CALL_GPU_ALL_TYPES(REGISTER); #undef REGISTER diff --git a/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc index 650c51fc765..e6a49da7fd2 100644 --- a/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc +++ b/tensorflow/core/kernels/extract_image_patches_op_gpu.cu.cc @@ -29,7 +29,7 @@ namespace functor { #define REGISTER(T) template struct ExtractImagePatchesForward; -TF_CALL_GPU_NUMBER_TYPES(REGISTER); +TF_CALL_GPU_ALL_TYPES(REGISTER); #undef REGISTER diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 60efdcb7a73..65c9510a1f2 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2525,7 +2525,9 @@ REGISTER_OP("ExtractImagePatches") .Attr("ksizes: list(int) >= 4") .Attr("strides: list(int) >= 4") .Attr("rates: list(int) >= 4") - .Attr("T: realnumbertype") + .Attr( + "T: {bfloat16, half, float, double, int8, int16, int32, int64, " + "uint8, uint16, uint32, uint64, complex64, complex128, bool}") .Attr(GetPaddingAttrString()) .SetShapeFn([](InferenceContext* c) { ShapeHandle input_shape;