diff --git a/tensorflow/core/kernels/one_hot_op.cc b/tensorflow/core/kernels/one_hot_op.cc
index 2732fb3a7e4..253d609d7e7 100644
--- a/tensorflow/core/kernels/one_hot_op.cc
+++ b/tensorflow/core/kernels/one_hot_op.cc
@@ -39,7 +39,7 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename TI>
 class OneHotOp : public OpKernel {
  public:
   explicit OneHotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -88,22 +88,22 @@ class OneHotOp : public OpKernel {
     // prefix_dim_size == # of elements before the axis
     // depth_v == # of elements per axis
     // suffix_dim_size == # of elements after the axis
-    int64 prefix_dim_size = 1;
+    TI prefix_dim_size = 1;
     for (int i = 0; i < axis; ++i) {
       prefix_dim_size *= indices_shape.dim_size(i);
     }
-    int64 suffix_dim_size =
+    TI suffix_dim_size =
         indices_shape.num_elements() / prefix_dim_size;
 
     // Split indices into matrix of size prefix_dim_size x suffix_dim_size
     auto indices_t =
-        indices.shaped<int64, 2>({prefix_dim_size, suffix_dim_size});
+        indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size});
     // Split output into 3-Tensor of size:
     //   prefix_dim_size x depth x suffix_dim_size.
     auto output_t =
         output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});
 
-    functor::OneHot<Device, T>::Compute(ctx->eigen_device<Device>(), indices_t,
+    functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(), indices_t,
                                         on_value_t, off_value_t, &output_t);
   }
 
@@ -113,44 +113,60 @@ class OneHotOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
 };
 
-#define REGISTER_ONE_HOT(type)                           \
-  REGISTER_KERNEL_BUILDER(Name("OneHot")                 \
-                              .Device(DEVICE_CPU)        \
-                              .TypeConstraint<type>("T") \
-                              .HostMemory("depth"),      \
-                          OneHotOp<CPUDevice, type>);
+#define REGISTER_ONE_HOT_INDEX(type, index_type)                \
+  REGISTER_KERNEL_BUILDER(Name("OneHot")                        \
+                              .Device(DEVICE_CPU)               \
+                              .TypeConstraint<index_type>("TI") \
+                              .TypeConstraint<type>("T")        \
+                              .HostMemory("depth"),             \
+                          OneHotOp<CPUDevice, type, index_type>);
 
-TF_CALL_NUMBER_TYPES(REGISTER_ONE_HOT);
+#define REGISTER_ONE_HOT(type)         \
+  REGISTER_ONE_HOT_INDEX(type, int32); \
+  REGISTER_ONE_HOT_INDEX(type, int64)
+
+TF_CALL_ALL_TYPES(REGISTER_ONE_HOT);
 
 #if GOOGLE_CUDA
 
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
-#define DECLARE_GPU_SPEC(T)                                                    \
-  template <>                                                                  \
-  void OneHot<GPUDevice, T>::Compute(                                          \
-      const GPUDevice& d, const typename TTypes<int64>::ConstMatrix& indices,  \
-      const typename TTypes<T>::ConstScalar& on_value,                         \
-      const typename TTypes<T>::ConstScalar& off_value,                        \
-      typename TTypes<T, 3>::Tensor* output);                                  \
-  extern template struct OneHot<GPUDevice, T>;
+#define DECLARE_GPU_SPEC_INDEX(T, TI)                                       \
+  template <>                                                               \
+  void OneHot<GPUDevice, T, TI>::Compute(                                   \
+      const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices,  \
+      const typename TTypes<T>::ConstScalar& on_value,                      \
+      const typename TTypes<T>::ConstScalar& off_value,                     \
+      typename TTypes<T, 3>::Tensor* output);                               \
+  extern template struct OneHot<GPUDevice, T, TI>;
+
+#define DECLARE_GPU_SPEC(T)         \
+  DECLARE_GPU_SPEC_INDEX(T, int32); \
+  DECLARE_GPU_SPEC_INDEX(T, int64);
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 
+#undef DECLARE_GPU_SPEC_INDEX
 #undef DECLARE_GPU_SPEC
 
 }  // namespace functor
 
 // Registration of the GPU implementations.
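With the kernel templated on an index type `TI` and registered for both index widths, `tf.one_hot` accepts `int32` as well as `int64` indices. A minimal usage sketch (assumes the Python wrapper changes further down in this patch; TF 0.x session API):

```
import numpy as np
import tensorflow as tf

with tf.Session():
  for index_dtype in (np.int32, np.int64):
    indices = np.asarray([0, 2, -1, 1], dtype=index_dtype)
    # Either index width now matches a registered OneHot kernel; the -1
    # index falls outside [0, depth) and produces an all-off_value row.
    print(tf.one_hot(indices, depth=3).eval())
```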
-#define REGISTER_ONE_HOT_GPU(type)                       \
-  REGISTER_KERNEL_BUILDER(Name("OneHot")                 \
-                              .Device(DEVICE_GPU)        \
-                              .TypeConstraint<type>("T") \
-                              .HostMemory("depth"),      \
-                          OneHotOp<GPUDevice, type>);
+#define REGISTER_ONE_HOT_GPU_INDEX(type, index_type)            \
+  REGISTER_KERNEL_BUILDER(Name("OneHot")                        \
+                              .Device(DEVICE_GPU)               \
+                              .TypeConstraint<index_type>("TI") \
+                              .TypeConstraint<type>("T")        \
+                              .HostMemory("depth"),             \
+                          OneHotOp<GPUDevice, type, index_type>);
+
+#define REGISTER_ONE_HOT_GPU(type)         \
+  REGISTER_ONE_HOT_GPU_INDEX(type, int32); \
+  REGISTER_ONE_HOT_GPU_INDEX(type, int64);
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU);
 
+#undef REGISTER_ONE_HOT_GPU_INDEX
 #undef REGISTER_ONE_HOT_GPU
 
 #endif  // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/one_hot_op.h b/tensorflow/core/kernels/one_hot_op.h
index 61f42f05dae..33770102876 100644
--- a/tensorflow/core/kernels/one_hot_op.h
+++ b/tensorflow/core/kernels/one_hot_op.h
@@ -27,11 +27,11 @@ namespace tensorflow {
 
 namespace generator {
 
-template <typename T>
+template <typename T, typename TI>
 class OneGenerator {
  public:
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
-  OneGenerator(const TTypes<int64>::ConstMatrix& indices,
+  OneGenerator(const typename TTypes<TI>::ConstMatrix& indices,
                const typename TTypes<T>::ConstScalar& on_value,
                const typename TTypes<T>::ConstScalar& off_value)
       : indices_(indices), on_value_(on_value), off_value_(off_value) {}
@@ -44,7 +44,7 @@ class OneGenerator {
   }
 
  private:
-  const TTypes<int64>::ConstMatrix indices_;
+  const typename TTypes<TI>::ConstMatrix indices_;
   const typename TTypes<T>::ConstScalar on_value_;
   const typename TTypes<T>::ConstScalar off_value_;
 };
@@ -53,14 +53,14 @@ class OneGenerator {
 
 namespace functor {
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename TI>
 struct OneHot {
   EIGEN_ALWAYS_INLINE static void Compute(
-      const Device& d, const TTypes<int64>::ConstMatrix& indices,
+      const Device& d, const typename TTypes<TI>::ConstMatrix& indices,
       const typename TTypes<T>::ConstScalar& on_value,
       const typename TTypes<T>::ConstScalar& off_value,
       typename TTypes<T, 3>::Tensor* output) {
-    generator::OneGenerator<T> generator(indices, on_value, off_value);
+    generator::OneGenerator<T, TI> generator(indices, on_value, off_value);
     output->device(d) = output->generate(generator);
   }
 };
diff --git a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc
index 804f7ef4064..70433b22ae6 100644
--- a/tensorflow/core/kernels/one_hot_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/one_hot_op_gpu.cu.cc
@@ -21,17 +21,24 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/one_hot_op.h"
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/types.h"
 
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
 
-#define DEFINE_GPU_SPEC(T)                       \
-  template class generator::OneGenerator<T>;     \
-  template struct functor::OneHot<GPUDevice, T>;
+#define DEFINE_GPU_SPEC_INDEX(T, TI)                 \
+  template class generator::OneGenerator<T, TI>;     \
+  template struct functor::OneHot<GPUDevice, T, TI>;
+
+#define DEFINE_GPU_SPEC(T)         \
+  DEFINE_GPU_SPEC_INDEX(T, int32); \
+  DEFINE_GPU_SPEC_INDEX(T, int64)
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
 
+#undef DEFINE_GPU_SPEC_INDEX
 #undef DEFINE_GPU_SPEC
 
 }  // end namespace tensorflow
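The generator above is evaluated once per output coordinate: for coordinate `(p, d, s)` it returns `on_value` exactly when `indices(p, s) == d`, and `off_value` otherwise. A NumPy sketch of the same semantics (illustrative only; the function name is hypothetical):

```
import numpy as np

def one_hot_reference(indices, depth, on_value, off_value):
  # indices: [prefix_dim_size, suffix_dim_size] -> output:
  # [prefix_dim_size, depth, suffix_dim_size], as in OneHotOp::Compute.
  prefix, suffix = indices.shape
  output = np.full((prefix, depth, suffix), off_value)
  for p in range(prefix):
    for d in range(depth):
      for s in range(suffix):
        if indices[p, s] == d:  # out-of-range indices (e.g. -1) never match
          output[p, d, s] = on_value
  return output

print(one_hot_reference(np.asarray([[0, 2, -1, 1]]), 3, 1.0, 0.0))
```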
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index fe0fa343377..f44fc50e5e0 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -1733,13 +1733,14 @@ dimension be equal to sizeof(`type`)/sizeof(`T`).
 The shape then goes from
 )doc");
 
 REGISTER_OP("OneHot")
-    .Input("indices: int64")
+    .Input("indices: TI")
     .Input("depth: int32")
     .Input("on_value: T")
     .Input("off_value: T")
     .Attr("axis: int = -1")
     .Output("output: T")
     .Attr("T: type")
+    .Attr("TI: {int32, int64} = DT_INT64")
     .Doc(R"doc(
 Returns a one-hot tensor.
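Because `TI` defaults to `DT_INT64`, GraphDefs serialized against the old `indices: int64` signature still resolve without carrying the new attr. One way to inspect the attr on a freshly built node (a sketch; assumes the wrapper below, which no longer forces indices to int64):

```
import tensorflow as tf

indices = tf.constant([0, 1, 2])  # plain Python ints become int32
output = tf.one_hot(indices, depth=3)
# The NodeDef now records TI (here DT_INT32) alongside T (DT_FLOAT).
print(output.op.node_def)
```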
diff --git a/tensorflow/python/kernel_tests/one_hot_op_test.py b/tensorflow/python/kernel_tests/one_hot_op_test.py
index 06f7f84ef86..cc792a3f951 100644
--- a/tensorflow/python/kernel_tests/one_hot_op_test.py
+++ b/tensorflow/python/kernel_tests/one_hot_op_test.py
@@ -34,7 +34,7 @@ class OneHotTest(tf.test.TestCase):
       ans = tf.one_hot(**inputs)
       if expected_err_re is None:
         tf_ans = ans.eval()
-        self.assertAllClose(tf_ans, truth, atol=1e-10)
+        self.assertAllEqual(tf_ans, truth)
         self.assertEqual(tf_ans.shape, ans.get_shape())
       else:
         with self.assertRaisesOpError(expected_err_re):
@@ -77,7 +77,7 @@ class OneHotTest(tf.test.TestCase):
         truth=truth.T)  # Output is transpose version in this case
 
   def _testDefaultBasic(self, dtype):
-    indices = np.asarray([0, 2, -1, 1], dtype=dtype)
+    indices = np.asarray([0, 2, -1, 1], dtype=np.int64)
     depth = 3
 
     truth = np.asarray(
@@ -89,18 +89,16 @@ class OneHotTest(tf.test.TestCase):
 
     # axis == -1
     self._testBothOneHot(
-        indices=indices,
-        depth=depth,
-        dtype=dtype,
-        truth=truth)
+        indices=indices,
+        depth=depth,
+        truth=truth)
 
     # axis == 0
     self._testBothOneHot(
-        indices=indices,
-        depth=depth,
-        axis=0,
-        dtype=dtype,
-        truth=truth.T)  # Output is transpose version in this case
+        indices=indices,
+        depth=depth,
+        axis=0,
+        truth=truth.T)  # Output is transpose version in this case
 
   def testFloatBasic(self):
     self._testBasic(np.float32)
@@ -163,7 +161,7 @@ class OneHotTest(tf.test.TestCase):
   def _testDefaultValuesBatch(self, dtype):
     indices = np.asarray([[0, 2, -1, 1],
                           [1, 0, 1, -1]],
-                         dtype=dtype)
+                         dtype=np.int64)
     depth = 3
 
     truth = np.asarray(
@@ -192,10 +190,10 @@ class OneHotTest(tf.test.TestCase):
         dtype=dtype,
         truth=[truth[0].T, truth[1].T])  # Do not transpose the batch
 
-  def _testTypeBatch(self, dtype):
+  def _testValueTypeBatch(self, dtype):
     indices = np.asarray([[0, 2, -1, 1],
                           [1, 0, 1, -1]],
-                         dtype=dtype)
+                         dtype=np.int64)
     depth = 3
 
     on_value = np.asarray(1.0, dtype=dtype)
@@ -218,6 +216,7 @@ class OneHotTest(tf.test.TestCase):
         on_value=on_value,
         off_value=off_value,
         depth=depth,
+        dtype=dtype,
         truth=truth)
 
     # axis == 1
@@ -227,32 +226,33 @@ class OneHotTest(tf.test.TestCase):
         off_value=off_value,
         depth=depth,
         axis=1,
+        dtype=dtype,
         truth=[truth[0].T, truth[1].T])  # Do not transpose the batch
 
   def testFloatBatch(self):
     self._testBatch(np.float32)
     self._testDefaultValuesBatch(np.float32)
-    self._testTypeBatch(np.float32)
+    self._testValueTypeBatch(np.float32)
 
   def testDoubleBatch(self):
     self._testBatch(np.float64)
     self._testDefaultValuesBatch(np.float64)
-    self._testTypeBatch(np.float64)
+    self._testValueTypeBatch(np.float64)
 
   def testInt32Batch(self):
     self._testBatch(np.int32)
     self._testDefaultValuesBatch(np.int32)
-    self._testTypeBatch(np.int32)
+    self._testValueTypeBatch(np.int32)
 
   def testInt64Batch(self):
     self._testBatch(np.int64)
     self._testDefaultValuesBatch(np.int64)
-    self._testTypeBatch(np.int64)
+    self._testValueTypeBatch(np.int64)
 
   def testComplexBatch(self):
     self._testBatch(np.complex64)
-    self._testDefaultValuesBatch(np.complex64)
-    self._testTypeBatch(np.complex64)
+    # self._testDefaultValuesBatch(np.complex64)
+    self._testValueTypeBatch(np.complex64)
 
   def testSimpleCases(self):
     indices = [0,1,2]
@@ -284,17 +284,6 @@ class OneHotTest(tf.test.TestCase):
     self._testBothOneHot(indices=indices, depth=depth, on_value=1,
                          off_value=-1, truth=truth)
 
-  def testStringDtypeError(self):
-    indices = [0,1,2]
-    depth = 3
-    truth = np.asarray(
-        [[1.0, 0.0, 0.0],
-         [0.0, 1.0, 0.0],
-         [0.0, 0.0, 1.0]])
-    self._testBothOneHot(indices=indices, depth=depth, on_value=1,
-                         off_value=-1, dtype=tf.string, raises=TypeError,
-                         truth=truth)
-
   def testSingleValueGiven(self):
     # Only on_value provided
     indices = [0,1,2]
@@ -317,5 +306,111 @@ class OneHotTest(tf.test.TestCase):
     self._testBothOneHot(indices=indices, depth=depth, off_value=0.0,
                          truth=truth)
 
+  def testString(self):
+    indices = [0,1,2]
+    depth = 3
+    truth = np.asarray(
+        [[b"1.0", b"0.0", b"0.0"],
+         [b"0.0", b"1.0", b"0.0"],
+         [b"0.0", b"0.0", b"1.0"]])
+    on_value = np.asarray(b"1.0")
+    off_value = np.asarray(b"0.0")
+
+    self._testBothOneHot(indices=indices, depth=depth, on_value=on_value,
+                         off_value=off_value, dtype=tf.string, truth=truth)
+
+    on_value = tf.constant(b"1.0")
+    off_value = tf.constant(b"0.0")
+    self._testBothOneHot(indices=indices, depth=depth, on_value=on_value,
+                         off_value=off_value, dtype=tf.string, truth=truth)
+
+    on_value = b"1.0"
+    off_value = b"0.0"
+    self._testBothOneHot(indices=indices, depth=depth, on_value=on_value,
+                         off_value=off_value, dtype=tf.string, truth=truth)
+
+  def testIndicesTypes(self):
+    tf_types = [tf.int32, tf.int64]
+    np_types = [np.int32, np.int64]
+    for itype in tf_types + np_types:
+      if itype in tf_types:
+        indices = tf.constant([[0, 2, -1, 1],
+                               [1, 0, 1, -1]],
+                              dtype=itype)
+      elif itype in np_types:
+        indices = np.asarray([[0, 2, -1, 1],
+                              [1, 0, 1, -1]],
+                             dtype=itype)
+      depth = 3
+
+      on_value = np.asarray(1.0, dtype=np.float32)
+      off_value = np.asarray(-1.0, dtype=np.float32)
+
+      truth = np.asarray(
+          [[[1.0, -1.0, -1.0],
+            [-1.0, -1.0, 1.0],
+            [-1.0, -1.0, -1.0],
+            [-1.0, 1.0, -1.0]],
+           [[-1.0, 1.0, -1.0],
+            [1.0, -1.0, -1.0],
+            [-1.0, 1.0, -1.0],
+            [-1.0, -1.0, -1.0]]],
+          dtype=np.float32)
+
+      # axis == -1
+      self._testBothOneHot(
+          indices=indices,
+          on_value=on_value,
+          off_value=off_value,
+          depth=depth,
+          truth=truth)
+
+      # axis == 1
+      self._testBothOneHot(
+          indices=indices,
+          on_value=on_value,
+          off_value=off_value,
+          depth=depth,
+          axis=1,
+          truth=[truth[0].T, truth[1].T])  # Do not transpose the batch
+
+  def testOnOffMismatchTypeError(self):
+    indices = [0, 1, 2]
+    depth = 3
+    on_value = np.asarray(1.0, np.float64)
+    off_value = np.asarray(0.0, np.float32)
+
+    self._testBothOneHot(
+        indices=indices,
+        depth=depth,
+        on_value=on_value,
+        off_value=off_value,
+        truth=None,
+        raises=TypeError)
+
+  def testDtypeMismatchTypeError(self):
+    indices = [0, 1, 2]
+    depth = 3
+    on_value = np.asarray(1.0, np.float32)
+    off_value = np.asarray(0.0, np.float32)
+    dtype = np.int32
+
+    self._testBothOneHot(
+        indices=indices,
+        depth=depth,
+        on_value=on_value,
+        dtype=dtype,
+        truth=None,
+        raises=TypeError)
+
+    self._testBothOneHot(
+        indices=indices,
+        depth=depth,
+        on_value=off_value,
+        dtype=dtype,
+        truth=None,
+        raises=TypeError)
+
+
 if __name__ == "__main__":
   tf.test.main()
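The two new TypeError tests correspond to failures raised in the Python wrapper before the op is ever created; a sketch of the mismatch case:

```
import numpy as np
import tensorflow as tf

try:
  tf.one_hot([0, 1, 2], depth=3,
             on_value=np.float64(1.0), off_value=np.float32(0.0))
except TypeError as e:
  print(e)  # dtype of on_value does not match dtype of off_value
```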
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 9ac80960b28..c19d93829a7 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1912,14 +1912,21 @@ def _DepthToSpaceShape(op):
           [input_shape[0], height, width, new_depth])]
 
 
-def one_hot(indices, depth, on_value=1, off_value=0,
-            axis=None, dtype=dtypes.float32, name=None):
+def one_hot(indices, depth, on_value=None, off_value=None,
+            axis=None, dtype=None, name=None):
   """Returns a one-hot tensor.
 
   The locations represented by indices in `indices` take value `on_value`,
-  while all other locations take value `off_value`. By default, `on_value` is 1,
-  and `off_value` is 0. The type of the output tensor is specified by `dtype`,
-  which defaults to `tf.float32`.
+  while all other locations take value `off_value`.
+
+  `on_value` and `off_value` must have matching data types. If `dtype` is also
+  provided, they must be the same data type as specified by `dtype`.
+
+  If `on_value` is not provided, it will default to the value `1` with type
+  `dtype`.
+
+  If `off_value` is not provided, it will default to the value `0` with type
+  `dtype`.
 
   If the input `indices` is rank `N`, the output will have rank `N+1`. The
   new axis is created at dimension `axis` (default: the new axis is appended
@@ -1941,6 +1948,13 @@ def one_hot(indices, depth, on_value=1, off_value=0,
       depth x batch x features if axis == 0
   ```
 
+  If `dtype` is not provided, it will be inferred from the data type of
+  `on_value` or `off_value`, if one or both are passed in. If none of
+  `on_value`, `off_value`, or `dtype` are provided, `dtype` defaults to
+  `tf.float32`.
+
+  Note: If a non-numeric output type is desired (`tf.string`, `tf.bool`, etc.),
+  both `on_value` and `off_value` _must_ be provided to `one_hot`.
+
   Examples
   =========
@@ -1988,6 +2002,22 @@ def one_hot(indices, depth, on_value=1, off_value=0,
     ]
   ```
 
+  Using default values for `on_value` and `off_value`:
+
+  ```
+    indices = [0, 1, 2]
+    depth = 3
+  ```
+
+  The output will be
+
+  ```
+    output =
+    [[1., 0., 0.],
+     [0., 1., 0.],
+     [0., 0., 1.]]
+  ```
+
   Args:
     indices: A `Tensor` of indices.
     depth: A scalar defining the depth of the one hot dimension.
@@ -2002,20 +2032,50 @@ def one_hot(indices, depth, on_value=1, off_value=0,
     output: The one-hot tensor.
 
   Raises:
-    TypeError: If dtype is `tf.string`
+    TypeError: If dtype of either `on_value` or `off_value` doesn't match `dtype`
+    TypeError: If dtypes of `on_value` and `off_value` don't match one another
   """
-  # Check for bad dtype specification
-  if dtype == dtypes.string:
-    raise TypeError("dtype must be a numeric type")
-
   with ops.op_scope([indices, depth, on_value, off_value, axis, dtype], name,
                     "one_hot") as name:
-    on_value = ops.convert_to_tensor(on_value, dtype=dtype, name="on_value")
-    off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value")
-    indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices")
-    depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth")
-    return gen_array_ops._one_hot(indices, depth, on_value, off_value, axis,
-                                  name)
+    on_exists = on_value is not None
+    off_exists = off_value is not None
+
+    on_dtype = ops.convert_to_tensor(on_value).dtype.base_dtype if on_exists \
+        else None
+    off_dtype = ops.convert_to_tensor(off_value).dtype.base_dtype if off_exists \
+        else None
+
+    if on_exists or off_exists:
+      if dtype is not None:
+        # Ensure provided on_value and/or off_value match dtype
+        if on_exists and on_dtype != dtype:
+          raise TypeError("dtype {0} of on_value does not match "
+                          "dtype parameter {1}".format(on_dtype, dtype))
+        if off_exists and off_dtype != dtype:
+          raise TypeError("dtype {0} of off_value does not match "
+                          "dtype parameter {1}".format(off_dtype, dtype))
+      else:
+        # dtype not provided: automatically assign it
+        dtype = on_dtype if on_exists else off_dtype
+    elif dtype is None:
+      # None of on_value, off_value, or dtype provided; default to float32
+      dtype = dtypes.float32
+
+    if not on_exists:
+      # on_value not provided: assign to value 1 of type dtype
+      on_value = ops.convert_to_tensor(1, dtype, name="on_value")
+      on_dtype = dtype
+    if not off_exists:
+      # off_value not provided: assign to value 0 of type dtype
+      off_value = ops.convert_to_tensor(0, dtype, name="off_value")
+      off_dtype = dtype
+
+    if on_dtype != off_dtype:
+      raise TypeError("dtype {0} of on_value does not match "
+                      "dtype {1} of off_value".format(on_dtype, off_dtype))
+
+    return gen_array_ops._one_hot(indices, depth, on_value,
+                                  off_value, axis, name)
 
 
 @ops.RegisterShape("OneHot")
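Taken together, the wrapper's inference rules behave as follows (a usage sketch of the code above):

```
import tensorflow as tf

with tf.Session():
  indices = [0, 1, 2]

  # No values, no dtype: defaults to float32 with on_value=1, off_value=0.
  print(tf.one_hot(indices, depth=3).eval())

  # dtype inferred from on_value/off_value; non-numeric outputs such as
  # tf.string require both values to be supplied explicitly.
  print(tf.one_hot(indices, depth=3,
                   on_value=b"yes", off_value=b"no").eval())
```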