diff --git a/tensorflow/core/kernels/argmax_op.cc b/tensorflow/core/kernels/argmax_op.cc index 63bef41a272..1478797227c 100644 --- a/tensorflow/core/kernels/argmax_op.cc +++ b/tensorflow/core/kernels/argmax_op.cc @@ -154,6 +154,7 @@ class ArgMinOp ArgMinOp); TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX); +TF_CALL_bool(REGISTER_ARGMAX); #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) @@ -194,7 +195,9 @@ namespace functor { extern template struct ArgMin; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); +TF_CALL_bool(DECLARE_GPU_SPECS); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS); +TF_CALL_bool(DECLARE_GPU_CLASS); #undef DECLARE_GPU_SPECS #undef DECLARE_GPU_CLASS @@ -233,6 +236,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS); ArgMinOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU); +TF_CALL_bool(REGISTER_ARGMAX_GPU); #undef REGISTER_ARGMAX_GPU diff --git a/tensorflow/core/kernels/argmax_op_gpu.cu.cc b/tensorflow/core/kernels/argmax_op_gpu.cu.cc index bd7c4b4027c..659048e6a1f 100644 --- a/tensorflow/core/kernels/argmax_op_gpu.cu.cc +++ b/tensorflow/core/kernels/argmax_op_gpu.cu.cc @@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice; template struct functor::ArgMin; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC); +TF_CALL_bool(DEFINE_GPU_SPEC); } // end namespace tensorflow diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index b4af577f45a..f4559041af9 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -1096,7 +1096,7 @@ REGISTER_OP("ArgMax") .Input("input: T") .Input("dimension: Tidx") .Output("output: output_type") - .Attr("T: numbertype") + .Attr("T: {numbertype, bool}") .Attr("Tidx: {int32, int64} = DT_INT32") .Attr("output_type: {int32, int64} = DT_INT64") .SetShapeFn(ArgOpShape); @@ -1105,7 +1105,7 @@ REGISTER_OP("ArgMin") .Input("input: T") .Input("dimension: Tidx") .Output("output: output_type") - .Attr("T: numbertype") + .Attr("T: {numbertype, bool}") .Attr("Tidx: {int32, int64} = DT_INT32") .Attr("output_type: {int32, int64} = DT_INT64") .SetShapeFn(ArgOpShape); diff --git a/tensorflow/python/kernel_tests/argmax_op_test.py b/tensorflow/python/kernel_tests/argmax_op_test.py index 023766c899d..8a6ac74849c 100644 --- a/tensorflow/python/kernel_tests/argmax_op_test.py +++ b/tensorflow/python/kernel_tests/argmax_op_test.py @@ -61,7 +61,7 @@ class ArgMaxTest(test.TestCase): self._testArg(method, x, axis, expected_values, False, expected_err_re) def _testBasic(self, dtype): - x = np.arange(200, dtype=dtype) + x = np.arange(200, dtype=np.float32).astype(np.bool_).astype(dtype) np.random.shuffle(x) # Check that argmin and argmax match numpy along the primary axis @@ -78,7 +78,9 @@ class ArgMaxTest(test.TestCase): def _testDim(self, dtype): shape = (3, 2, 4, 5, 6, 3, 7) - x = np.arange(functools.reduce(lambda x, y: x * y, shape), dtype=dtype) + x = np.arange( + functools.reduce(lambda x, y: x * y, shape), + dtype=np.float32).astype(dtype) np.random.shuffle(x) x = x.reshape(shape) @@ -124,6 +126,11 @@ class ArgMaxTest(test.TestCase): self._testTieBreaking(np.int64) self._testDim(np.int64) + def testBool(self): + self._testBasic(np.bool_) + self._testTieBreaking(np.bool_) + self._testDim(np.bool_) + def testEmpty(self): with self.cached_session(): for op in math_ops.argmin, math_ops.argmax: