Merge pull request #2187 from samjabrahams/onehottypes

Fix convert_to_tensor error in tf.one_hot
This commit is contained in:
ebrevdo 2016-05-22 17:25:15 -07:00
commit 4be0965c01
6 changed files with 262 additions and 83 deletions

View File

@ -39,7 +39,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename T>
template <typename Device, typename T, typename TI>
class OneHotOp : public OpKernel {
public:
explicit OneHotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -88,22 +88,22 @@ class OneHotOp : public OpKernel {
// prefix_dim_size == # of elements before the axis
// depth_v == # of elements per axis
// suffix_dim_size == # of elements after the axis
int64 prefix_dim_size = 1;
TI prefix_dim_size = 1;
for (int i = 0; i < axis; ++i) {
prefix_dim_size *= indices_shape.dim_size(i);
}
int64 suffix_dim_size =
TI suffix_dim_size =
indices_shape.num_elements() / prefix_dim_size;
// Split indices into matrix of size prefix_dim_size x suffix_dim_size
auto indices_t =
indices.shaped<int64, 2>({prefix_dim_size, suffix_dim_size});
indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size});
// Split output into 3-Tensor of size:
// prefix_dim_size x depth x suffix_dim_size.
auto output_t =
output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});
functor::OneHot<Device, T>::Compute(ctx->eigen_device<Device>(), indices_t,
functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(), indices_t,
on_value_t, off_value_t, &output_t);
}
@ -113,44 +113,60 @@ class OneHotOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
};
#define REGISTER_ONE_HOT(type) \
REGISTER_KERNEL_BUILDER(Name("OneHot") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.HostMemory("depth"), \
OneHotOp<CPUDevice, type>);
#define REGISTER_ONE_HOT_INDEX(type, index_type) \
REGISTER_KERNEL_BUILDER(Name("OneHot") \
.Device(DEVICE_CPU) \
.TypeConstraint<index_type>("TI") \
.TypeConstraint<type>("T") \
.HostMemory("depth"), \
OneHotOp<CPUDevice, type, index_type>);
TF_CALL_NUMBER_TYPES(REGISTER_ONE_HOT);
#define REGISTER_ONE_HOT(type) \
REGISTER_ONE_HOT_INDEX(type, int32); \
REGISTER_ONE_HOT_INDEX(type, int64)
TF_CALL_ALL_TYPES(REGISTER_ONE_HOT);
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void OneHot<GPUDevice, T>::Compute( \
const GPUDevice& d, const typename TTypes<int64>::ConstMatrix& indices, \
const typename TTypes<T>::ConstScalar& on_value, \
const typename TTypes<T>::ConstScalar& off_value, \
typename TTypes<T, 3>::Tensor* output); \
extern template struct OneHot<GPUDevice, T>;
#define DECLARE_GPU_SPEC_INDEX(T, TI) \
template <> \
void OneHot<GPUDevice, T, TI>::Compute( \
const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices, \
const typename TTypes<T>::ConstScalar& on_value, \
const typename TTypes<T>::ConstScalar& off_value, \
typename TTypes<T, 3>::Tensor* output); \
extern template struct OneHot<GPUDevice, T, TI>;
#define DECLARE_GPU_SPEC(T) \
DECLARE_GPU_SPEC_INDEX(T, int32); \
DECLARE_GPU_SPEC_INDEX(T, int64); \
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
#undef DECLARE_GPU_SPEC_INDEX
#undef DECLARE_GPU_SPEC
} // namespace functor
// Registration of the GPU implementations.
#define REGISTER_ONE_HOT_GPU(type) \
REGISTER_KERNEL_BUILDER(Name("OneHot") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \
.HostMemory("depth"), \
OneHotOp<GPUDevice, type>);
#define REGISTER_ONE_HOT_GPU_INDEX(type, index_type) \
REGISTER_KERNEL_BUILDER(Name("OneHot") \
.Device(DEVICE_GPU) \
.TypeConstraint<index_type>("TI") \
.TypeConstraint<type>("T") \
.HostMemory("depth"), \
OneHotOp<GPUDevice, type, index_type>);
#define REGISTER_ONE_HOT_GPU(type) \
REGISTER_ONE_HOT_GPU_INDEX(type, int32); \
REGISTER_ONE_HOT_GPU_INDEX(type, int64); \
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU);
#undef REGISTER_ONE_HOT_GPU_INDEX
#undef REGISTER_ONE_HOT_GPU
#endif // GOOGLE_CUDA

View File

@ -27,11 +27,11 @@ namespace tensorflow {
namespace generator {
template <typename T>
template <typename T, typename TI>
class OneGenerator {
public:
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
OneGenerator(const TTypes<int64>::ConstMatrix& indices,
OneGenerator(const typename TTypes<TI>::ConstMatrix& indices,
const typename TTypes<T>::ConstScalar& on_value,
const typename TTypes<T>::ConstScalar& off_value)
: indices_(indices), on_value_(on_value), off_value_(off_value) {}
@ -44,7 +44,7 @@ class OneGenerator {
}
private:
const TTypes<int64>::ConstMatrix indices_;
const typename TTypes<TI>::ConstMatrix indices_;
const typename TTypes<T>::ConstScalar on_value_;
const typename TTypes<T>::ConstScalar off_value_;
};
@ -53,14 +53,14 @@ class OneGenerator {
namespace functor {
template <typename Device, typename T>
template <typename Device, typename T, typename TI>
struct OneHot {
EIGEN_ALWAYS_INLINE static void Compute(
const Device& d, const TTypes<int64>::ConstMatrix& indices,
const Device& d, const typename TTypes<TI>::ConstMatrix& indices,
const typename TTypes<T>::ConstScalar& on_value,
const typename TTypes<T>::ConstScalar& off_value,
typename TTypes<T, 3>::Tensor* output) {
generator::OneGenerator<T> generator(indices, on_value, off_value);
generator::OneGenerator<T, TI> generator(indices, on_value, off_value);
output->device(d) = output->generate(generator);
}
};

View File

@ -21,17 +21,24 @@ limitations under the License.
#include "tensorflow/core/kernels/one_hot_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_SPEC(T) \
template class generator::OneGenerator<T>; \
template struct functor::OneHot<GPUDevice, T>;
// Explicitly instantiates the one-hot generator and the GPU OneHot functor for
// a single (value type T, index type TI) pair.
#define DEFINE_GPU_SPEC_INDEX(T, TI) \
template class generator::OneGenerator<T, TI>; \
template struct functor::OneHot<GPUDevice, T, TI>;
// Instantiates both index-type variants (int32 and int64) for value type T.
// Note the final line deliberately has no trailing backslash, so the
// TF_CALL_GPU_NUMBER_TYPES invocation below is a separate statement.
#define DEFINE_GPU_SPEC(T) \
DEFINE_GPU_SPEC_INDEX(T, int32); \
DEFINE_GPU_SPEC_INDEX(T, int64)
// Apply DEFINE_GPU_SPEC to each type listed by TF_CALL_GPU_NUMBER_TYPES.
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
// The helper macros are only needed for the instantiations above.
#undef DEFINE_GPU_SPEC_INDEX
#undef DEFINE_GPU_SPEC
} // end namespace tensorflow

View File

@ -1733,13 +1733,14 @@ dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from
)doc");
REGISTER_OP("OneHot")
.Input("indices: int64")
.Input("indices: TI")
.Input("depth: int32")
.Input("on_value: T")
.Input("off_value: T")
.Attr("axis: int = -1")
.Output("output: T")
.Attr("T: type")
.Attr("TI: {int32, int64} = DT_INT64")
.Doc(R"doc(
Returns a one-hot tensor.

View File

@ -34,7 +34,7 @@ class OneHotTest(tf.test.TestCase):
ans = tf.one_hot(**inputs)
if expected_err_re is None:
tf_ans = ans.eval()
self.assertAllClose(tf_ans, truth, atol=1e-10)
self.assertAllEqual(tf_ans, truth)
self.assertEqual(tf_ans.shape, ans.get_shape())
else:
with self.assertRaisesOpError(expected_err_re):
@ -77,7 +77,7 @@ class OneHotTest(tf.test.TestCase):
truth=truth.T) # Output is transpose version in this case
def _testDefaultBasic(self, dtype):
indices = np.asarray([0, 2, -1, 1], dtype=dtype)
indices = np.asarray([0, 2, -1, 1], dtype=np.int64)
depth = 3
truth = np.asarray(
@ -89,18 +89,16 @@ class OneHotTest(tf.test.TestCase):
# axis == -1
self._testBothOneHot(
indices=indices,
depth=depth,
dtype=dtype,
truth=truth)
indices=indices,
depth=depth,
truth=truth)
# axis == 0
self._testBothOneHot(
indices=indices,
depth=depth,
axis=0,
dtype=dtype,
truth=truth.T) # Output is transpose version in this case
indices=indices,
depth=depth,
axis=0,
truth=truth.T) # Output is transpose version in this case
def testFloatBasic(self):
self._testBasic(np.float32)
@ -163,7 +161,7 @@ class OneHotTest(tf.test.TestCase):
def _testDefaultValuesBatch(self, dtype):
indices = np.asarray([[0, 2, -1, 1],
[1, 0, 1, -1]],
dtype=dtype)
dtype=np.int64)
depth = 3
truth = np.asarray(
@ -192,10 +190,10 @@ class OneHotTest(tf.test.TestCase):
dtype=dtype,
truth=[truth[0].T, truth[1].T]) # Do not transpose the batch
def _testTypeBatch(self, dtype):
def _testValueTypeBatch(self, dtype):
indices = np.asarray([[0, 2, -1, 1],
[1, 0, 1, -1]],
dtype=dtype)
dtype=np.int64)
depth = 3
on_value = np.asarray(1.0, dtype=dtype)
@ -218,6 +216,7 @@ class OneHotTest(tf.test.TestCase):
on_value=on_value,
off_value=off_value,
depth=depth,
dtype=dtype,
truth=truth)
# axis == 1
@ -227,32 +226,33 @@ class OneHotTest(tf.test.TestCase):
off_value=off_value,
depth=depth,
axis=1,
dtype=dtype,
truth=[truth[0].T, truth[1].T]) # Do not transpose the batch
def testFloatBatch(self):
self._testBatch(np.float32)
self._testDefaultValuesBatch(np.float32)
self._testTypeBatch(np.float32)
self._testValueTypeBatch(np.float32)
def testDoubleBatch(self):
self._testBatch(np.float64)
self._testDefaultValuesBatch(np.float64)
self._testTypeBatch(np.float64)
self._testValueTypeBatch(np.float64)
def testInt32Batch(self):
self._testBatch(np.int32)
self._testDefaultValuesBatch(np.int32)
self._testTypeBatch(np.int32)
self._testValueTypeBatch(np.int32)
def testInt64Batch(self):
self._testBatch(np.int64)
self._testDefaultValuesBatch(np.int64)
self._testTypeBatch(np.int64)
self._testValueTypeBatch(np.int64)
def testComplexBatch(self):
self._testBatch(np.complex64)
self._testDefaultValuesBatch(np.complex64)
self._testTypeBatch(np.complex64)
# self._testDefaultValuesBatch(np.complex64)
self._testValueTypeBatch(np.complex64)
def testSimpleCases(self):
indices = [0,1,2]
@ -284,17 +284,6 @@ class OneHotTest(tf.test.TestCase):
self._testBothOneHot(indices=indices, depth=depth, on_value=1,
off_value=-1, truth=truth)
def testStringDtypeError(self):
indices = [0,1,2]
depth = 3
truth = np.asarray(
[[1.0, 0.0, 0.0],
[0.0, 1.0, 0.0],
[0.0, 0.0, 1.0]])
self._testBothOneHot(indices=indices, depth=depth, on_value=1,
off_value=-1, dtype=tf.string, raises=TypeError,
truth=truth)
def testSingleValueGiven(self):
# Only on_value provided
indices = [0,1,2]
@ -317,5 +306,111 @@ class OneHotTest(tf.test.TestCase):
self._testBothOneHot(indices=indices, depth=depth,
off_value=0.0, truth=truth)
def testString(self):
  """Checks string-valued one_hot output for three ways of passing on/off."""
  expected = np.asarray(
      [[b"1.0", b"0.0", b"0.0"],
       [b"0.0", b"1.0", b"0.0"],
       [b"0.0", b"0.0", b"1.0"]])

  # The on/off values are supplied, in turn, as numpy arrays, as TensorFlow
  # constants, and as plain Python byte strings.
  wrappers = (np.asarray, tf.constant, lambda raw: raw)
  for wrap in wrappers:
    hot = wrap(b"1.0")
    cold = wrap(b"0.0")
    self._testBothOneHot(indices=[0,1,2], depth=3, on_value=hot,
                         off_value=cold, dtype=tf.string, truth=expected)
def testIndicesTypes(self):
  """Runs one_hot with int32/int64 indices, as tf tensors and numpy arrays."""
  raw_indices = [[0, 2, -1, 1],
                 [1, 0, 1, -1]]

  # Pair each dtype with the constructor that must build it. Branching on
  # `itype in tf_types` is unreliable: list membership uses `==`, and
  # (presumably -- confirm against tf.DType.__eq__) a tf.DType compares equal
  # to the matching numpy type, which would send the numpy dtypes through
  # tf.constant and leave the ndarray input path untested.
  index_builders = [
      (tf.constant, tf.int32),
      (tf.constant, tf.int64),
      (np.asarray, np.int32),
      (np.asarray, np.int64),
  ]

  # Everything below is independent of the index type, so build it once.
  depth = 3
  on_value = np.asarray(1.0, dtype=np.float32)
  off_value = np.asarray(-1.0, dtype=np.float32)
  truth = np.asarray(
      [[[1.0, -1.0, -1.0],
        [-1.0, -1.0, 1.0],
        [-1.0, -1.0, -1.0],
        [-1.0, 1.0, -1.0]],
       [[-1.0, 1.0, -1.0],
        [1.0, -1.0, -1.0],
        [-1.0, 1.0, -1.0],
        [-1.0, -1.0, -1.0]]],
      dtype=np.float32)

  for build_indices, itype in index_builders:
    indices = build_indices(raw_indices, dtype=itype)

    # axis == -1
    self._testBothOneHot(
        indices=indices,
        on_value=on_value,
        off_value=off_value,
        depth=depth,
        truth=truth)

    # axis == 1
    self._testBothOneHot(
        indices=indices,
        on_value=on_value,
        off_value=off_value,
        depth=depth,
        axis=1,
        truth=[truth[0].T, truth[1].T])  # Do not transpose the batch
def testOnOffMismatchTypeError(self):
  """Expects TypeError when on_value and off_value have different dtypes."""
  # on_value is float64 while off_value is float32; no explicit dtype given.
  mismatched_args = dict(
      indices=[0, 1, 2],
      depth=3,
      on_value=np.asarray(1.0, np.float64),
      off_value=np.asarray(0.0, np.float32),
      truth=None,
      raises=TypeError)
  self._testBothOneHot(**mismatched_args)
def testDtypeMismatchTypeError(self):
  """Expects TypeError when on_value or off_value disagrees with dtype.

  Both values are float32 while the requested dtype is int32. Each value is
  checked on its own so that the on_value and the off_value validation paths
  are exercised separately.
  """
  indices = [0, 1, 2]
  depth = 3
  on_value = np.asarray(1.0, np.float32)
  off_value = np.asarray(0.0, np.float32)
  dtype = np.int32

  # Only on_value supplied: its dtype conflicts with `dtype`.
  self._testBothOneHot(
      indices=indices,
      depth=depth,
      on_value=on_value,
      dtype=dtype,
      truth=None,
      raises=TypeError)

  # Only off_value supplied: its dtype conflicts with `dtype`. It must be
  # passed as `off_value`; passing it as `on_value` would just repeat the
  # check above.
  self._testBothOneHot(
      indices=indices,
      depth=depth,
      off_value=off_value,
      dtype=dtype,
      truth=None,
      raises=TypeError)
if __name__ == "__main__":
tf.test.main()

View File

@ -1912,14 +1912,21 @@ def _DepthToSpaceShape(op):
[input_shape[0], height, width, new_depth])]
def one_hot(indices, depth, on_value=1, off_value=0,
axis=None, dtype=dtypes.float32, name=None):
def one_hot(indices, depth, on_value=None, off_value=None,
axis=None, dtype=None, name=None):
"""Returns a one-hot tensor.
The locations represented by indices in `indices` take value `on_value`,
while all other locations take value `off_value`. By default, `on_value` is 1,
and `off_value` is 0. The type of the output tensor is specified by `dtype`,
which defaults to `tf.float32`.
while all other locations take value `off_value`.
`on_value` and `off_value` must have matching data types. If `dtype` is also
provided, they must be the same data type as specified by `dtype`.
If `on_value` is not provided, it will default to the value `1` with type
`dtype`.
If `off_value` is not provided, it will default to the value `0` with type
`dtype`.
If the input `indices` is rank `N`, the output will have rank `N+1`. The
new axis is created at dimension `axis` (default: the new axis is appended
@ -1941,6 +1948,13 @@ def one_hot(indices, depth, on_value=1, off_value=0,
depth x batch x features if axis == 0
```
If `dtype` is not provided, it is inferred from the data type of `on_value`
or `off_value`, if one or both are passed in. If none of `on_value`,
`off_value`, or `dtype` is provided, `dtype` will default to the value
`tf.float32`.
Note: If a non-numeric data type output is desired (tf.string, tf.bool, etc.),
both `on_value` and `off_value` _must_ be provided to `one_hot`.
Examples
=========
@ -1988,6 +2002,22 @@ def one_hot(indices, depth, on_value=1, off_value=0,
]
```
Using default values for `on_value` and `off_value`:
```
indices = [0, 1, 2]
depth = 3
```
The output will be
```
output =
[[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]]
```
Args:
indices: A `Tensor` of indices.
depth: A scalar defining the depth of the one hot dimension.
@ -2002,20 +2032,50 @@ def one_hot(indices, depth, on_value=1, off_value=0,
output: The one-hot tensor.
Raises:
TypeError: If dtype is `tf.string`
TypeError: If the dtype of `on_value` or `off_value` doesn't match `dtype`.
TypeError: If the dtypes of `on_value` and `off_value` don't match each other.
"""
# Check for bad dtype specification
if dtype == dtypes.string:
raise TypeError("dtype must be a numeric type")
with ops.op_scope([indices, depth, on_value, off_value,
axis, dtype], name, "one_hot") as name:
on_value = ops.convert_to_tensor(on_value, dtype=dtype, name="on_value")
off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value")
indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices")
depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth")
return gen_array_ops._one_hot(indices, depth, on_value, off_value, axis,
name)
on_exists = on_value is not None
off_exists = off_value is not None
on_dtype = ops.convert_to_tensor(on_value).dtype.base_dtype if on_exists \
else None
off_dtype = ops.convert_to_tensor(off_value).dtype.base_dtype if off_exists\
else None
if on_exists or off_exists:
if dtype is not None:
# Ensure provided on_value and/or off_value match dtype
if (on_exists and on_dtype != dtype):
raise TypeError("dtype {0} of on_value does not match " \
"dtype parameter {1}".format(on_dtype, dtype))
if (off_exists and off_dtype != dtype):
raise TypeError("dtype {0} of off_value does not match " \
"dtype parameter {1}".format(off_dtype, dtype))
else:
# dtype not provided: automatically assign it
dtype = on_dtype if on_exists else off_dtype
elif dtype is None:
# None of on_value, off_value, or dtype provided. Default dtype to float32
dtype = dtypes.float32
if not on_exists:
# on_value not provided: assign to value 1 of type dtype
on_value = ops.convert_to_tensor(1, dtype, name="on_value")
on_dtype = dtype
if not off_exists:
# off_value not provided: assign to value 0 of type dtype
off_value = ops.convert_to_tensor(0, dtype, name="off_value")
off_dtype = dtype
if on_dtype != off_dtype:
raise TypeError("dtype {0} of on_value does not match " \
"dtype {1} of off_value".format(on_dtype, off_dtype))
return gen_array_ops._one_hot(indices, depth, on_value,
off_value, axis, name)
@ops.RegisterShape("OneHot")