Support argmin/argmax of boolean tensors on CPUs and GPUs.
PiperOrigin-RevId: 306503450 Change-Id: Ia54c1c2bd5a46eec1795a13914129fc601a4d09e
This commit is contained in:
parent
3e414c752b
commit
b0cd75dc99
@ -154,6 +154,7 @@ class ArgMinOp
|
||||
ArgMinOp<CPUDevice, type, int32>);
|
||||
|
||||
TF_CALL_REAL_NUMBER_TYPES(REGISTER_ARGMAX);
|
||||
TF_CALL_bool(REGISTER_ARGMAX);
|
||||
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
|
||||
@ -194,7 +195,9 @@ namespace functor {
|
||||
extern template struct ArgMin<GPUDevice, T, int32>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
|
||||
TF_CALL_bool(DECLARE_GPU_SPECS);
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);
|
||||
TF_CALL_bool(DECLARE_GPU_CLASS);
|
||||
|
||||
#undef DECLARE_GPU_SPECS
|
||||
#undef DECLARE_GPU_CLASS
|
||||
@ -233,6 +236,7 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_CLASS);
|
||||
ArgMinOp<GPUDevice, type, int32>);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ARGMAX_GPU);
|
||||
TF_CALL_bool(REGISTER_ARGMAX_GPU);
|
||||
|
||||
#undef REGISTER_ARGMAX_GPU
|
||||
|
||||
|
@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
template struct functor::ArgMin<GPUDevice, T, int32>;
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
|
||||
TF_CALL_bool(DEFINE_GPU_SPEC);
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
|
@ -1096,7 +1096,7 @@ REGISTER_OP("ArgMax")
|
||||
.Input("input: T")
|
||||
.Input("dimension: Tidx")
|
||||
.Output("output: output_type")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("T: {numbertype, bool}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("output_type: {int32, int64} = DT_INT64")
|
||||
.SetShapeFn(ArgOpShape);
|
||||
@ -1105,7 +1105,7 @@ REGISTER_OP("ArgMin")
|
||||
.Input("input: T")
|
||||
.Input("dimension: Tidx")
|
||||
.Output("output: output_type")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("T: {numbertype, bool}")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.Attr("output_type: {int32, int64} = DT_INT64")
|
||||
.SetShapeFn(ArgOpShape);
|
||||
|
@ -61,7 +61,7 @@ class ArgMaxTest(test.TestCase):
|
||||
self._testArg(method, x, axis, expected_values, False, expected_err_re)
|
||||
|
||||
def _testBasic(self, dtype):
|
||||
x = np.arange(200, dtype=dtype)
|
||||
x = np.arange(200, dtype=np.float32).astype(np.bool_).astype(dtype)
|
||||
np.random.shuffle(x)
|
||||
|
||||
# Check that argmin and argmax match numpy along the primary axis
|
||||
@ -78,7 +78,9 @@ class ArgMaxTest(test.TestCase):
|
||||
|
||||
def _testDim(self, dtype):
|
||||
shape = (3, 2, 4, 5, 6, 3, 7)
|
||||
x = np.arange(functools.reduce(lambda x, y: x * y, shape), dtype=dtype)
|
||||
x = np.arange(
|
||||
functools.reduce(lambda x, y: x * y, shape),
|
||||
dtype=np.float32).astype(dtype)
|
||||
np.random.shuffle(x)
|
||||
x = x.reshape(shape)
|
||||
|
||||
@ -124,6 +126,11 @@ class ArgMaxTest(test.TestCase):
|
||||
self._testTieBreaking(np.int64)
|
||||
self._testDim(np.int64)
|
||||
|
||||
def testBool(self):
|
||||
self._testBasic(np.bool_)
|
||||
self._testTieBreaking(np.bool_)
|
||||
self._testDim(np.bool_)
|
||||
|
||||
def testEmpty(self):
|
||||
with self.cached_session():
|
||||
for op in math_ops.argmin, math_ops.argmax:
|
||||
|
Loading…
Reference in New Issue
Block a user