Support argmin/argmax of boolean tensors on CPUs and GPUs.

PiperOrigin-RevId: 306503450
Change-Id: Ia54c1c2bd5a46eec1795a13914129fc601a4d09e
This commit is contained in:
Akshay Modi 2020-04-14 13:29:32 -07:00 committed by TensorFlower Gardener
parent 3e414c752b
commit b0cd75dc99
4 changed files with 16 additions and 4 deletions

View File

@ -154,6 +154,7 @@ class ArgMinOp
ArgMinOp<CPUDevice, type, int32>); ArgMinOp<CPUDevice, type, int32>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX); TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);
TF_CALL_bool(REGISTER_ARGMAX);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
@ -194,7 +195,9 @@ namespace functor {
extern template struct ArgMin<GPUDevice, T, int32>; extern template struct ArgMin<GPUDevice, T, int32>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
TF_CALL_bool(DECLARE_GPU_SPECS);
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS); TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);
TF_CALL_bool(DECLARE_GPU_CLASS);
#undef DECLARE_GPU_SPECS #undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_CLASS #undef DECLARE_GPU_CLASS
@ -233,6 +236,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);
ArgMinOp<GPUDevice, type, int32>); ArgMinOp<GPUDevice, type, int32>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU); TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);
TF_CALL_bool(REGISTER_ARGMAX_GPU);
#undef REGISTER_ARGMAX_GPU #undef REGISTER_ARGMAX_GPU

View File

@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::ArgMin<GPUDevice, T, int32>; template struct functor::ArgMin<GPUDevice, T, int32>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC); TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
TF_CALL_bool(DEFINE_GPU_SPEC);
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -1096,7 +1096,7 @@ REGISTER_OP("ArgMax")
.Input("input: T") .Input("input: T")
.Input("dimension: Tidx") .Input("dimension: Tidx")
.Output("output: output_type") .Output("output: output_type")
.Attr("T: numbertype") .Attr("T: {numbertype, bool}")
.Attr("Tidx: {int32, int64} = DT_INT32") .Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("output_type: {int32, int64} = DT_INT64") .Attr("output_type: {int32, int64} = DT_INT64")
.SetShapeFn(ArgOpShape); .SetShapeFn(ArgOpShape);
@ -1105,7 +1105,7 @@ REGISTER_OP("ArgMin")
.Input("input: T") .Input("input: T")
.Input("dimension: Tidx") .Input("dimension: Tidx")
.Output("output: output_type") .Output("output: output_type")
.Attr("T: numbertype") .Attr("T: {numbertype, bool}")
.Attr("Tidx: {int32, int64} = DT_INT32") .Attr("Tidx: {int32, int64} = DT_INT32")
.Attr("output_type: {int32, int64} = DT_INT64") .Attr("output_type: {int32, int64} = DT_INT64")
.SetShapeFn(ArgOpShape); .SetShapeFn(ArgOpShape);

View File

@ -61,7 +61,7 @@ class ArgMaxTest(test.TestCase):
self._testArg(method, x, axis, expected_values, False, expected_err_re) self._testArg(method, x, axis, expected_values, False, expected_err_re)
def _testBasic(self, dtype): def _testBasic(self, dtype):
x = np.arange(200, dtype=dtype) x = np.arange(200, dtype=np.float32).astype(np.bool_).astype(dtype)
np.random.shuffle(x) np.random.shuffle(x)
# Check that argmin and argmax match numpy along the primary axis # Check that argmin and argmax match numpy along the primary axis
@ -78,7 +78,9 @@ class ArgMaxTest(test.TestCase):
def _testDim(self, dtype): def _testDim(self, dtype):
shape = (3, 2, 4, 5, 6, 3, 7) shape = (3, 2, 4, 5, 6, 3, 7)
x = np.arange(functools.reduce(lambda x, y: x * y, shape), dtype=dtype) x = np.arange(
functools.reduce(lambda x, y: x * y, shape),
dtype=np.float32).astype(dtype)
np.random.shuffle(x) np.random.shuffle(x)
x = x.reshape(shape) x = x.reshape(shape)
@ -124,6 +126,11 @@ class ArgMaxTest(test.TestCase):
self._testTieBreaking(np.int64) self._testTieBreaking(np.int64)
self._testDim(np.int64) self._testDim(np.int64)
def testBool(self):
self._testBasic(np.bool_)
self._testTieBreaking(np.bool_)
self._testDim(np.bool_)
def testEmpty(self): def testEmpty(self):
with self.cached_session(): with self.cached_session():
for op in math_ops.argmin, math_ops.argmax: for op in math_ops.argmin, math_ops.argmax: