Merge pull request #2187 from samjabrahams/onehottypes
Fix convert_to_tensor error in tf.one_hot
This commit is contained in:
commit
4be0965c01
@ -39,7 +39,7 @@ namespace tensorflow {
|
|||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T, typename TI>
|
||||||
class OneHotOp : public OpKernel {
|
class OneHotOp : public OpKernel {
|
||||||
public:
|
public:
|
||||||
explicit OneHotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
explicit OneHotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
@ -88,22 +88,22 @@ class OneHotOp : public OpKernel {
|
|||||||
// prefix_dim_size == # of elements before the axis
|
// prefix_dim_size == # of elements before the axis
|
||||||
// depth_v == # of elements per axis
|
// depth_v == # of elements per axis
|
||||||
// suffix_dim_size == # of elements after the axis
|
// suffix_dim_size == # of elements after the axis
|
||||||
int64 prefix_dim_size = 1;
|
TI prefix_dim_size = 1;
|
||||||
for (int i = 0; i < axis; ++i) {
|
for (int i = 0; i < axis; ++i) {
|
||||||
prefix_dim_size *= indices_shape.dim_size(i);
|
prefix_dim_size *= indices_shape.dim_size(i);
|
||||||
}
|
}
|
||||||
int64 suffix_dim_size =
|
TI suffix_dim_size =
|
||||||
indices_shape.num_elements() / prefix_dim_size;
|
indices_shape.num_elements() / prefix_dim_size;
|
||||||
|
|
||||||
// Split indices into matrix of size prefix_dim_size x suffix_dim_size
|
// Split indices into matrix of size prefix_dim_size x suffix_dim_size
|
||||||
auto indices_t =
|
auto indices_t =
|
||||||
indices.shaped<int64, 2>({prefix_dim_size, suffix_dim_size});
|
indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size});
|
||||||
// Split output into 3-Tensor of size:
|
// Split output into 3-Tensor of size:
|
||||||
// prefix_dim_size x depth x suffix_dim_size.
|
// prefix_dim_size x depth x suffix_dim_size.
|
||||||
auto output_t =
|
auto output_t =
|
||||||
output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});
|
output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});
|
||||||
|
|
||||||
functor::OneHot<Device, T>::Compute(ctx->eigen_device<Device>(), indices_t,
|
functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(), indices_t,
|
||||||
on_value_t, off_value_t, &output_t);
|
on_value_t, off_value_t, &output_t);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -113,44 +113,60 @@ class OneHotOp : public OpKernel {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
#define REGISTER_ONE_HOT(type) \
|
#define REGISTER_ONE_HOT_INDEX(type, index_type) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("OneHot") \
|
REGISTER_KERNEL_BUILDER(Name("OneHot") \
|
||||||
.Device(DEVICE_CPU) \
|
.Device(DEVICE_CPU) \
|
||||||
.TypeConstraint<type>("T") \
|
.TypeConstraint<index_type>("TI") \
|
||||||
.HostMemory("depth"), \
|
.TypeConstraint<type>("T") \
|
||||||
OneHotOp<CPUDevice, type>);
|
.HostMemory("depth"), \
|
||||||
|
OneHotOp<CPUDevice, type, index_type>);
|
||||||
|
|
||||||
TF_CALL_NUMBER_TYPES(REGISTER_ONE_HOT);
|
#define REGISTER_ONE_HOT(type) \
|
||||||
|
REGISTER_ONE_HOT_INDEX(type, int32); \
|
||||||
|
REGISTER_ONE_HOT_INDEX(type, int64)
|
||||||
|
|
||||||
|
TF_CALL_ALL_TYPES(REGISTER_ONE_HOT);
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
// Forward declarations of the functor specializations for GPU.
|
// Forward declarations of the functor specializations for GPU.
|
||||||
namespace functor {
|
namespace functor {
|
||||||
#define DECLARE_GPU_SPEC(T) \
|
#define DECLARE_GPU_SPEC_INDEX(T, TI) \
|
||||||
template <> \
|
template <> \
|
||||||
void OneHot<GPUDevice, T>::Compute( \
|
void OneHot<GPUDevice, T, TI>::Compute( \
|
||||||
const GPUDevice& d, const typename TTypes<int64>::ConstMatrix& indices, \
|
const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices, \
|
||||||
const typename TTypes<T>::ConstScalar& on_value, \
|
const typename TTypes<T>::ConstScalar& on_value, \
|
||||||
const typename TTypes<T>::ConstScalar& off_value, \
|
const typename TTypes<T>::ConstScalar& off_value, \
|
||||||
typename TTypes<T, 3>::Tensor* output); \
|
typename TTypes<T, 3>::Tensor* output); \
|
||||||
extern template struct OneHot<GPUDevice, T>;
|
extern template struct OneHot<GPUDevice, T, TI>;
|
||||||
|
|
||||||
|
#define DECLARE_GPU_SPEC(T) \
|
||||||
|
DECLARE_GPU_SPEC_INDEX(T, int32); \
|
||||||
|
DECLARE_GPU_SPEC_INDEX(T, int64); \
|
||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
|
||||||
|
|
||||||
|
#undef DECLARE_GPU_SPEC_INDEX
|
||||||
#undef DECLARE_GPU_SPEC
|
#undef DECLARE_GPU_SPEC
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
|
||||||
// Registration of the GPU implementations.
|
// Registration of the GPU implementations.
|
||||||
#define REGISTER_ONE_HOT_GPU(type) \
|
#define REGISTER_ONE_HOT_GPU_INDEX(type, index_type) \
|
||||||
REGISTER_KERNEL_BUILDER(Name("OneHot") \
|
REGISTER_KERNEL_BUILDER(Name("OneHot") \
|
||||||
.Device(DEVICE_GPU) \
|
.Device(DEVICE_GPU) \
|
||||||
.TypeConstraint<type>("T") \
|
.TypeConstraint<index_type>("TI") \
|
||||||
.HostMemory("depth"), \
|
.TypeConstraint<type>("T") \
|
||||||
OneHotOp<GPUDevice, type>);
|
.HostMemory("depth"), \
|
||||||
|
OneHotOp<GPUDevice, type, index_type>);
|
||||||
|
|
||||||
|
#define REGISTER_ONE_HOT_GPU(type) \
|
||||||
|
REGISTER_ONE_HOT_GPU_INDEX(type, int32); \
|
||||||
|
REGISTER_ONE_HOT_GPU_INDEX(type, int64); \
|
||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU);
|
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU);
|
||||||
|
|
||||||
|
#undef REGISTER_ONE_HOT_GPU_INDEX
|
||||||
#undef REGISTER_ONE_HOT_GPU
|
#undef REGISTER_ONE_HOT_GPU
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
@ -27,11 +27,11 @@ namespace tensorflow {
|
|||||||
|
|
||||||
namespace generator {
|
namespace generator {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T, typename TI>
|
||||||
class OneGenerator {
|
class OneGenerator {
|
||||||
public:
|
public:
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||||
OneGenerator(const TTypes<int64>::ConstMatrix& indices,
|
OneGenerator(const typename TTypes<TI>::ConstMatrix& indices,
|
||||||
const typename TTypes<T>::ConstScalar& on_value,
|
const typename TTypes<T>::ConstScalar& on_value,
|
||||||
const typename TTypes<T>::ConstScalar& off_value)
|
const typename TTypes<T>::ConstScalar& off_value)
|
||||||
: indices_(indices), on_value_(on_value), off_value_(off_value) {}
|
: indices_(indices), on_value_(on_value), off_value_(off_value) {}
|
||||||
@ -44,7 +44,7 @@ class OneGenerator {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const TTypes<int64>::ConstMatrix indices_;
|
const typename TTypes<TI>::ConstMatrix indices_;
|
||||||
const typename TTypes<T>::ConstScalar on_value_;
|
const typename TTypes<T>::ConstScalar on_value_;
|
||||||
const typename TTypes<T>::ConstScalar off_value_;
|
const typename TTypes<T>::ConstScalar off_value_;
|
||||||
};
|
};
|
||||||
@ -53,14 +53,14 @@ class OneGenerator {
|
|||||||
|
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T, typename TI>
|
||||||
struct OneHot {
|
struct OneHot {
|
||||||
EIGEN_ALWAYS_INLINE static void Compute(
|
EIGEN_ALWAYS_INLINE static void Compute(
|
||||||
const Device& d, const TTypes<int64>::ConstMatrix& indices,
|
const Device& d, const typename TTypes<TI>::ConstMatrix& indices,
|
||||||
const typename TTypes<T>::ConstScalar& on_value,
|
const typename TTypes<T>::ConstScalar& on_value,
|
||||||
const typename TTypes<T>::ConstScalar& off_value,
|
const typename TTypes<T>::ConstScalar& off_value,
|
||||||
typename TTypes<T, 3>::Tensor* output) {
|
typename TTypes<T, 3>::Tensor* output) {
|
||||||
generator::OneGenerator<T> generator(indices, on_value, off_value);
|
generator::OneGenerator<T, TI> generator(indices, on_value, off_value);
|
||||||
output->device(d) = output->generate(generator);
|
output->device(d) = output->generate(generator);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -21,17 +21,24 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/kernels/one_hot_op.h"
|
#include "tensorflow/core/kernels/one_hot_op.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
|
||||||
#define DEFINE_GPU_SPEC(T) \
|
#define DEFINE_GPU_SPEC_INDEX(T, TI) \
|
||||||
template class generator::OneGenerator<T>; \
|
template class generator::OneGenerator<T, TI>; \
|
||||||
template struct functor::OneHot<GPUDevice, T>;
|
template struct functor::OneHot<GPUDevice, T, TI>;
|
||||||
|
|
||||||
|
#define DEFINE_GPU_SPEC(T) \
|
||||||
|
DEFINE_GPU_SPEC_INDEX(T, int32); \
|
||||||
|
DEFINE_GPU_SPEC_INDEX(T, int64)
|
||||||
|
|
||||||
|
|
||||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
|
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
|
||||||
|
|
||||||
|
#undef DEFINE_GPU_SPEC_INDEX
|
||||||
#undef DEFINE_GPU_SPEC
|
#undef DEFINE_GPU_SPEC
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -1733,13 +1733,14 @@ dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from
|
|||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("OneHot")
|
REGISTER_OP("OneHot")
|
||||||
.Input("indices: int64")
|
.Input("indices: TI")
|
||||||
.Input("depth: int32")
|
.Input("depth: int32")
|
||||||
.Input("on_value: T")
|
.Input("on_value: T")
|
||||||
.Input("off_value: T")
|
.Input("off_value: T")
|
||||||
.Attr("axis: int = -1")
|
.Attr("axis: int = -1")
|
||||||
.Output("output: T")
|
.Output("output: T")
|
||||||
.Attr("T: type")
|
.Attr("T: type")
|
||||||
|
.Attr("TI: {int32, int64} = DT_INT64")
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Returns a one-hot tensor.
|
Returns a one-hot tensor.
|
||||||
|
|
||||||
|
@ -34,7 +34,7 @@ class OneHotTest(tf.test.TestCase):
|
|||||||
ans = tf.one_hot(**inputs)
|
ans = tf.one_hot(**inputs)
|
||||||
if expected_err_re is None:
|
if expected_err_re is None:
|
||||||
tf_ans = ans.eval()
|
tf_ans = ans.eval()
|
||||||
self.assertAllClose(tf_ans, truth, atol=1e-10)
|
self.assertAllEqual(tf_ans, truth)
|
||||||
self.assertEqual(tf_ans.shape, ans.get_shape())
|
self.assertEqual(tf_ans.shape, ans.get_shape())
|
||||||
else:
|
else:
|
||||||
with self.assertRaisesOpError(expected_err_re):
|
with self.assertRaisesOpError(expected_err_re):
|
||||||
@ -77,7 +77,7 @@ class OneHotTest(tf.test.TestCase):
|
|||||||
truth=truth.T) # Output is transpose version in this case
|
truth=truth.T) # Output is transpose version in this case
|
||||||
|
|
||||||
def _testDefaultBasic(self, dtype):
|
def _testDefaultBasic(self, dtype):
|
||||||
indices = np.asarray([0, 2, -1, 1], dtype=dtype)
|
indices = np.asarray([0, 2, -1, 1], dtype=np.int64)
|
||||||
depth = 3
|
depth = 3
|
||||||
|
|
||||||
truth = np.asarray(
|
truth = np.asarray(
|
||||||
@ -89,18 +89,16 @@ class OneHotTest(tf.test.TestCase):
|
|||||||
|
|
||||||
# axis == -1
|
# axis == -1
|
||||||
self._testBothOneHot(
|
self._testBothOneHot(
|
||||||
indices=indices,
|
indices=indices,
|
||||||
depth=depth,
|
depth=depth,
|
||||||
dtype=dtype,
|
truth=truth)
|
||||||
truth=truth)
|
|
||||||
|
|
||||||
# axis == 0
|
# axis == 0
|
||||||
self._testBothOneHot(
|
self._testBothOneHot(
|
||||||
indices=indices,
|
indices=indices,
|
||||||
depth=depth,
|
depth=depth,
|
||||||
axis=0,
|
axis=0,
|
||||||
dtype=dtype,
|
truth=truth.T) # Output is transpose version in this case
|
||||||
truth=truth.T) # Output is transpose version in this case
|
|
||||||
|
|
||||||
def testFloatBasic(self):
|
def testFloatBasic(self):
|
||||||
self._testBasic(np.float32)
|
self._testBasic(np.float32)
|
||||||
@ -163,7 +161,7 @@ class OneHotTest(tf.test.TestCase):
|
|||||||
def _testDefaultValuesBatch(self, dtype):
|
def _testDefaultValuesBatch(self, dtype):
|
||||||
indices = np.asarray([[0, 2, -1, 1],
|
indices = np.asarray([[0, 2, -1, 1],
|
||||||
[1, 0, 1, -1]],
|
[1, 0, 1, -1]],
|
||||||
dtype=dtype)
|
dtype=np.int64)
|
||||||
depth = 3
|
depth = 3
|
||||||
|
|
||||||
truth = np.asarray(
|
truth = np.asarray(
|
||||||
@ -192,10 +190,10 @@ class OneHotTest(tf.test.TestCase):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
truth=[truth[0].T, truth[1].T]) # Do not transpose the batch
|
truth=[truth[0].T, truth[1].T]) # Do not transpose the batch
|
||||||
|
|
||||||
def _testTypeBatch(self, dtype):
|
def _testValueTypeBatch(self, dtype):
|
||||||
indices = np.asarray([[0, 2, -1, 1],
|
indices = np.asarray([[0, 2, -1, 1],
|
||||||
[1, 0, 1, -1]],
|
[1, 0, 1, -1]],
|
||||||
dtype=dtype)
|
dtype=np.int64)
|
||||||
depth = 3
|
depth = 3
|
||||||
|
|
||||||
on_value = np.asarray(1.0, dtype=dtype)
|
on_value = np.asarray(1.0, dtype=dtype)
|
||||||
@ -218,6 +216,7 @@ class OneHotTest(tf.test.TestCase):
|
|||||||
on_value=on_value,
|
on_value=on_value,
|
||||||
off_value=off_value,
|
off_value=off_value,
|
||||||
depth=depth,
|
depth=depth,
|
||||||
|
dtype=dtype,
|
||||||
truth=truth)
|
truth=truth)
|
||||||
|
|
||||||
# axis == 1
|
# axis == 1
|
||||||
@ -227,32 +226,33 @@ class OneHotTest(tf.test.TestCase):
|
|||||||
off_value=off_value,
|
off_value=off_value,
|
||||||
depth=depth,
|
depth=depth,
|
||||||
axis=1,
|
axis=1,
|
||||||
|
dtype=dtype,
|
||||||
truth=[truth[0].T, truth[1].T]) # Do not transpose the batch
|
truth=[truth[0].T, truth[1].T]) # Do not transpose the batch
|
||||||
|
|
||||||
def testFloatBatch(self):
|
def testFloatBatch(self):
|
||||||
self._testBatch(np.float32)
|
self._testBatch(np.float32)
|
||||||
self._testDefaultValuesBatch(np.float32)
|
self._testDefaultValuesBatch(np.float32)
|
||||||
self._testTypeBatch(np.float32)
|
self._testValueTypeBatch(np.float32)
|
||||||
|
|
||||||
def testDoubleBatch(self):
|
def testDoubleBatch(self):
|
||||||
self._testBatch(np.float64)
|
self._testBatch(np.float64)
|
||||||
self._testDefaultValuesBatch(np.float64)
|
self._testDefaultValuesBatch(np.float64)
|
||||||
self._testTypeBatch(np.float64)
|
self._testValueTypeBatch(np.float64)
|
||||||
|
|
||||||
def testInt32Batch(self):
|
def testInt32Batch(self):
|
||||||
self._testBatch(np.int32)
|
self._testBatch(np.int32)
|
||||||
self._testDefaultValuesBatch(np.int32)
|
self._testDefaultValuesBatch(np.int32)
|
||||||
self._testTypeBatch(np.int32)
|
self._testValueTypeBatch(np.int32)
|
||||||
|
|
||||||
def testInt64Batch(self):
|
def testInt64Batch(self):
|
||||||
self._testBatch(np.int64)
|
self._testBatch(np.int64)
|
||||||
self._testDefaultValuesBatch(np.int64)
|
self._testDefaultValuesBatch(np.int64)
|
||||||
self._testTypeBatch(np.int64)
|
self._testValueTypeBatch(np.int64)
|
||||||
|
|
||||||
def testComplexBatch(self):
|
def testComplexBatch(self):
|
||||||
self._testBatch(np.complex64)
|
self._testBatch(np.complex64)
|
||||||
self._testDefaultValuesBatch(np.complex64)
|
# self._testDefaultValuesBatch(np.complex64)
|
||||||
self._testTypeBatch(np.complex64)
|
self._testValueTypeBatch(np.complex64)
|
||||||
|
|
||||||
def testSimpleCases(self):
|
def testSimpleCases(self):
|
||||||
indices = [0,1,2]
|
indices = [0,1,2]
|
||||||
@ -284,17 +284,6 @@ class OneHotTest(tf.test.TestCase):
|
|||||||
self._testBothOneHot(indices=indices, depth=depth, on_value=1,
|
self._testBothOneHot(indices=indices, depth=depth, on_value=1,
|
||||||
off_value=-1, truth=truth)
|
off_value=-1, truth=truth)
|
||||||
|
|
||||||
def testStringDtypeError(self):
|
|
||||||
indices = [0,1,2]
|
|
||||||
depth = 3
|
|
||||||
truth = np.asarray(
|
|
||||||
[[1.0, 0.0, 0.0],
|
|
||||||
[0.0, 1.0, 0.0],
|
|
||||||
[0.0, 0.0, 1.0]])
|
|
||||||
self._testBothOneHot(indices=indices, depth=depth, on_value=1,
|
|
||||||
off_value=-1, dtype=tf.string, raises=TypeError,
|
|
||||||
truth=truth)
|
|
||||||
|
|
||||||
def testSingleValueGiven(self):
|
def testSingleValueGiven(self):
|
||||||
# Only on_value provided
|
# Only on_value provided
|
||||||
indices = [0,1,2]
|
indices = [0,1,2]
|
||||||
@ -317,5 +306,111 @@ class OneHotTest(tf.test.TestCase):
|
|||||||
self._testBothOneHot(indices=indices, depth=depth,
|
self._testBothOneHot(indices=indices, depth=depth,
|
||||||
off_value=0.0, truth=truth)
|
off_value=0.0, truth=truth)
|
||||||
|
|
||||||
|
def testString(self):
|
||||||
|
indices = [0,1,2]
|
||||||
|
depth = 3
|
||||||
|
truth = np.asarray(
|
||||||
|
[[b"1.0", b"0.0", b"0.0"],
|
||||||
|
[b"0.0", b"1.0", b"0.0"],
|
||||||
|
[b"0.0", b"0.0", b"1.0"]])
|
||||||
|
on_value = np.asarray(b"1.0")
|
||||||
|
off_value = np.asarray(b"0.0")
|
||||||
|
|
||||||
|
self._testBothOneHot(indices=indices, depth=depth, on_value=on_value,
|
||||||
|
off_value=off_value, dtype=tf.string, truth=truth)
|
||||||
|
|
||||||
|
on_value = tf.constant(b"1.0")
|
||||||
|
off_value = tf.constant(b"0.0")
|
||||||
|
self._testBothOneHot(indices=indices, depth=depth, on_value=on_value,
|
||||||
|
off_value=off_value, dtype=tf.string, truth=truth)
|
||||||
|
|
||||||
|
on_value = b"1.0"
|
||||||
|
off_value = b"0.0"
|
||||||
|
self._testBothOneHot(indices=indices, depth=depth, on_value=on_value,
|
||||||
|
off_value=off_value, dtype=tf.string, truth=truth)
|
||||||
|
|
||||||
|
def testIndicesTypes(self):
|
||||||
|
tf_types = [tf.int32, tf.int64]
|
||||||
|
np_types = [np.int32, np.int64]
|
||||||
|
for itype in tf_types + np_types:
|
||||||
|
if itype in tf_types:
|
||||||
|
indices = tf.constant([[0, 2, -1, 1],
|
||||||
|
[1, 0, 1, -1]],
|
||||||
|
dtype=itype)
|
||||||
|
elif itype in np_types:
|
||||||
|
indices = np.asarray([[0, 2, -1, 1],
|
||||||
|
[1, 0, 1, -1]],
|
||||||
|
dtype=itype)
|
||||||
|
depth = 3
|
||||||
|
|
||||||
|
on_value = np.asarray(1.0, dtype=np.float32)
|
||||||
|
off_value = np.asarray(-1.0, dtype=np.float32)
|
||||||
|
|
||||||
|
truth = np.asarray(
|
||||||
|
[[[1.0, -1.0, -1.0],
|
||||||
|
[-1.0, -1.0, 1.0],
|
||||||
|
[-1.0, -1.0, -1.0],
|
||||||
|
[-1.0, 1.0, -1.0]],
|
||||||
|
[[-1.0, 1.0, -1.0],
|
||||||
|
[1.0, -1.0, -1.0],
|
||||||
|
[-1.0, 1.0, -1.0],
|
||||||
|
[-1.0, -1.0, -1.0]]],
|
||||||
|
dtype=np.float32)
|
||||||
|
|
||||||
|
# axis == -1
|
||||||
|
self._testBothOneHot(
|
||||||
|
indices=indices,
|
||||||
|
on_value=on_value,
|
||||||
|
off_value=off_value,
|
||||||
|
depth=depth,
|
||||||
|
truth=truth)
|
||||||
|
|
||||||
|
# axis == 1
|
||||||
|
self._testBothOneHot(
|
||||||
|
indices=indices,
|
||||||
|
on_value=on_value,
|
||||||
|
off_value=off_value,
|
||||||
|
depth=depth,
|
||||||
|
axis=1,
|
||||||
|
truth=[truth[0].T, truth[1].T]) # Do not transpose the batch
|
||||||
|
|
||||||
|
def testOnOffMismatchTypeError(self):
|
||||||
|
indices = [0, 1, 2]
|
||||||
|
depth = 3
|
||||||
|
on_value = np.asarray(1.0, np.float64)
|
||||||
|
off_value = np.asarray(0.0, np.float32)
|
||||||
|
|
||||||
|
self._testBothOneHot(
|
||||||
|
indices=indices,
|
||||||
|
depth=depth,
|
||||||
|
on_value=on_value,
|
||||||
|
off_value=off_value,
|
||||||
|
truth=None,
|
||||||
|
raises=TypeError)
|
||||||
|
|
||||||
|
def testDtypeMismatchTypeError(self):
|
||||||
|
indices = [0, 1, 2]
|
||||||
|
depth = 3
|
||||||
|
on_value = np.asarray(1.0, np.float32)
|
||||||
|
off_value = np.asarray(0.0, np.float32)
|
||||||
|
dtype = np.int32
|
||||||
|
|
||||||
|
self._testBothOneHot(
|
||||||
|
indices=indices,
|
||||||
|
depth=depth,
|
||||||
|
on_value=on_value,
|
||||||
|
dtype=dtype,
|
||||||
|
truth=None,
|
||||||
|
raises=TypeError)
|
||||||
|
|
||||||
|
self._testBothOneHot(
|
||||||
|
indices=indices,
|
||||||
|
depth=depth,
|
||||||
|
on_value=off_value,
|
||||||
|
dtype=dtype,
|
||||||
|
truth=None,
|
||||||
|
raises=TypeError)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -1912,14 +1912,21 @@ def _DepthToSpaceShape(op):
|
|||||||
[input_shape[0], height, width, new_depth])]
|
[input_shape[0], height, width, new_depth])]
|
||||||
|
|
||||||
|
|
||||||
def one_hot(indices, depth, on_value=1, off_value=0,
|
def one_hot(indices, depth, on_value=None, off_value=None,
|
||||||
axis=None, dtype=dtypes.float32, name=None):
|
axis=None, dtype=None, name=None):
|
||||||
"""Returns a one-hot tensor.
|
"""Returns a one-hot tensor.
|
||||||
|
|
||||||
The locations represented by indices in `indices` take value `on_value`,
|
The locations represented by indices in `indices` take value `on_value`,
|
||||||
while all other locations take value `off_value`. By default, `on_value` is 1,
|
while all other locations take value `off_value`.
|
||||||
and `off_value` is 0. The type of the output tensor is specified by `dtype`,
|
|
||||||
which defaults to `tf.float32`.
|
`on_value` and `off_value` must have matching data types. If `dtype` is also
|
||||||
|
provided, they must be the same data type as specified by `dtype`.
|
||||||
|
|
||||||
|
If `on_value` is not provided, it will default to the value `1` with type
|
||||||
|
`dtype`.
|
||||||
|
|
||||||
|
If `off_value` is not provided, it will default to the value `0` with type
|
||||||
|
`dtype`.
|
||||||
|
|
||||||
If the input `indices` is rank `N`, the output will have rank `N+1`. The
|
If the input `indices` is rank `N`, the output will have rank `N+1`. The
|
||||||
new axis is created at dimension `axis` (default: the new axis is appended
|
new axis is created at dimension `axis` (default: the new axis is appended
|
||||||
@ -1941,6 +1948,13 @@ def one_hot(indices, depth, on_value=1, off_value=0,
|
|||||||
depth x batch x features if axis == 0
|
depth x batch x features if axis == 0
|
||||||
```
|
```
|
||||||
|
|
||||||
|
If `dtype` is not provided, it will attempt to assume the data type of
|
||||||
|
`on_value` or `off_value`, if one or both are passed in. If none of
|
||||||
|
`on_value`, `off_value`, or `dtype` are provided, `dtype` will default to the
|
||||||
|
value `tf.float32`.
|
||||||
|
|
||||||
|
Note: If a non-numeric data type output is desired (tf.string, tf.bool, etc.),
|
||||||
|
both `on_value` and `off_value` _must_ be provided to `one_hot`.
|
||||||
|
|
||||||
Examples
|
Examples
|
||||||
=========
|
=========
|
||||||
@ -1988,6 +2002,22 @@ def one_hot(indices, depth, on_value=1, off_value=0,
|
|||||||
]
|
]
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Using default values for `on_value` and `off_value`:
|
||||||
|
|
||||||
|
```
|
||||||
|
indices = [0, 1, 2]
|
||||||
|
depth = 3
|
||||||
|
```
|
||||||
|
|
||||||
|
The output will be
|
||||||
|
|
||||||
|
```
|
||||||
|
output =
|
||||||
|
[[1., 0., 0.],
|
||||||
|
[0., 1., 0.],
|
||||||
|
[0., 0., 1.]]
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
indices: A `Tensor` of indices.
|
indices: A `Tensor` of indices.
|
||||||
depth: A scalar defining the depth of the one hot dimension.
|
depth: A scalar defining the depth of the one hot dimension.
|
||||||
@ -2002,20 +2032,50 @@ def one_hot(indices, depth, on_value=1, off_value=0,
|
|||||||
output: The one-hot tensor.
|
output: The one-hot tensor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If dtype is `tf.string`
|
TypeError: If the dtype of either `on_value` or `off_value` does not match `dtype`
|
||||||
|
TypeError: If the dtypes of `on_value` and `off_value` do not match one another
|
||||||
"""
|
"""
|
||||||
# Check for bad dtype specification
|
|
||||||
if dtype == dtypes.string:
|
|
||||||
raise TypeError("dtype must be a numeric type")
|
|
||||||
|
|
||||||
with ops.op_scope([indices, depth, on_value, off_value,
|
with ops.op_scope([indices, depth, on_value, off_value,
|
||||||
axis, dtype], name, "one_hot") as name:
|
axis, dtype], name, "one_hot") as name:
|
||||||
on_value = ops.convert_to_tensor(on_value, dtype=dtype, name="on_value")
|
on_exists = on_value is not None
|
||||||
off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value")
|
off_exists = off_value is not None
|
||||||
indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices")
|
|
||||||
depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth")
|
on_dtype = ops.convert_to_tensor(on_value).dtype.base_dtype if on_exists \
|
||||||
return gen_array_ops._one_hot(indices, depth, on_value, off_value, axis,
|
else None
|
||||||
name)
|
off_dtype = ops.convert_to_tensor(off_value).dtype.base_dtype if off_exists\
|
||||||
|
else None
|
||||||
|
|
||||||
|
if on_exists or off_exists:
|
||||||
|
if dtype is not None:
|
||||||
|
# Ensure provided on_value and/or off_value match dtype
|
||||||
|
if (on_exists and on_dtype != dtype):
|
||||||
|
raise TypeError("dtype {0} of on_value does not match " \
|
||||||
|
"dtype parameter {1}".format(on_dtype, dtype))
|
||||||
|
if (off_exists and off_dtype != dtype):
|
||||||
|
raise TypeError("dtype {0} of off_value does not match " \
|
||||||
|
"dtype parameter {1}".format(off_dtype, dtype))
|
||||||
|
else:
|
||||||
|
# dtype not provided: automatically assign it
|
||||||
|
dtype = on_dtype if on_exists else off_dtype
|
||||||
|
elif dtype is None:
|
||||||
|
# None of on_value, off_value, or dtype provided. Default dtype to float32
|
||||||
|
dtype = dtypes.float32
|
||||||
|
|
||||||
|
if not on_exists:
|
||||||
|
# on_value not provided: assign to value 1 of type dtype
|
||||||
|
on_value = ops.convert_to_tensor(1, dtype, name="on_value")
|
||||||
|
on_dtype = dtype
|
||||||
|
if not off_exists:
|
||||||
|
# off_value not provided: assign to value 0 of type dtype
|
||||||
|
off_value = ops.convert_to_tensor(0, dtype, name="off_value")
|
||||||
|
off_dtype = dtype
|
||||||
|
|
||||||
|
if on_dtype != off_dtype:
|
||||||
|
raise TypeError("dtype {0} of on_value does not match " \
|
||||||
|
"dtype {1} of off_value".format(on_dtype, off_dtype))
|
||||||
|
|
||||||
|
return gen_array_ops._one_hot(indices, depth, on_value,
|
||||||
|
off_value, axis, name)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("OneHot")
|
@ops.RegisterShape("OneHot")
|
||||||
|
Loading…
Reference in New Issue
Block a user