Merge pull request #2187 from samjabrahams/onehottypes
Fix convert_to_tensor error in tf.one_hot
commit 4be0965c01
@@ -39,7 +39,7 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename TI>
 class OneHotOp : public OpKernel {
  public:
  explicit OneHotOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -88,22 +88,22 @@ class OneHotOp : public OpKernel {
     // prefix_dim_size == # of elements before the axis
     // depth_v == # of elements per axis
     // suffix_dim_size == # of elements after the axis
-    int64 prefix_dim_size = 1;
+    TI prefix_dim_size = 1;
     for (int i = 0; i < axis; ++i) {
       prefix_dim_size *= indices_shape.dim_size(i);
     }
-    int64 suffix_dim_size =
+    TI suffix_dim_size =
         indices_shape.num_elements() / prefix_dim_size;
 
     // Split indices into matrix of size prefix_dim_size x suffix_dim_size
     auto indices_t =
-        indices.shaped<int64, 2>({prefix_dim_size, suffix_dim_size});
+        indices.shaped<TI, 2>({prefix_dim_size, suffix_dim_size});
     // Split output into 3-Tensor of size:
     //   prefix_dim_size x depth x suffix_dim_size.
     auto output_t =
         output->shaped<T, 3>({prefix_dim_size, depth_v, suffix_dim_size});
 
-    functor::OneHot<Device, T>::Compute(ctx->eigen_device<Device>(), indices_t,
+    functor::OneHot<Device, T, TI>::Compute(ctx->eigen_device<Device>(), indices_t,
                                         on_value_t, off_value_t, &output_t);
   }
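The comments in this hunk describe how the kernel flattens an arbitrary-rank indices tensor around the one-hot axis. A minimal NumPy sketch of that shape arithmetic, using hypothetical values (indices shape (2, 4), axis 1, depth 3):

```
import numpy as np

# Illustrative sketch of the shape arithmetic above: the kernel views the
# indices tensor as a prefix_dim_size x suffix_dim_size matrix, and the
# output as prefix_dim_size x depth x suffix_dim_size.
indices_shape = (2, 4)
axis = 1
depth = 3

prefix_dim_size = 1
for i in range(axis):
    prefix_dim_size *= indices_shape[i]  # product of dims before the axis
suffix_dim_size = np.prod(indices_shape) // prefix_dim_size

print(prefix_dim_size, depth, suffix_dim_size)  # 2 3 4
```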
@@ -113,44 +113,60 @@ class OneHotOp : public OpKernel {
   TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp);
 };
 
-#define REGISTER_ONE_HOT(type)                            \
-  REGISTER_KERNEL_BUILDER(Name("OneHot")                  \
-                              .Device(DEVICE_CPU)         \
-                              .TypeConstraint<type>("T")  \
-                              .HostMemory("depth"),       \
-                          OneHotOp<CPUDevice, type>);
+#define REGISTER_ONE_HOT_INDEX(type, index_type)                \
+  REGISTER_KERNEL_BUILDER(Name("OneHot")                        \
+                              .Device(DEVICE_CPU)               \
+                              .TypeConstraint<index_type>("TI") \
+                              .TypeConstraint<type>("T")        \
+                              .HostMemory("depth"),             \
+                          OneHotOp<CPUDevice, type, index_type>);
 
-TF_CALL_NUMBER_TYPES(REGISTER_ONE_HOT);
+#define REGISTER_ONE_HOT(type)         \
+  REGISTER_ONE_HOT_INDEX(type, int32); \
+  REGISTER_ONE_HOT_INDEX(type, int64)
+
+TF_CALL_ALL_TYPES(REGISTER_ONE_HOT);
 
 #if GOOGLE_CUDA
 
 // Forward declarations of the functor specializations for GPU.
 namespace functor {
-#define DECLARE_GPU_SPEC(T)                                                   \
-  template <>                                                                 \
-  void OneHot<GPUDevice, T>::Compute(                                         \
-      const GPUDevice& d, const typename TTypes<int64>::ConstMatrix& indices, \
-      const typename TTypes<T>::ConstScalar& on_value,                        \
-      const typename TTypes<T>::ConstScalar& off_value,                       \
-      typename TTypes<T, 3>::Tensor* output);                                 \
-  extern template struct OneHot<GPUDevice, T>;
+#define DECLARE_GPU_SPEC_INDEX(T, TI)                                       \
+  template <>                                                               \
+  void OneHot<GPUDevice, T, TI>::Compute(                                   \
+      const GPUDevice& d, const typename TTypes<TI>::ConstMatrix& indices,  \
+      const typename TTypes<T>::ConstScalar& on_value,                      \
+      const typename TTypes<T>::ConstScalar& off_value,                     \
+      typename TTypes<T, 3>::Tensor* output);                               \
+  extern template struct OneHot<GPUDevice, T, TI>;
+
+#define DECLARE_GPU_SPEC(T)         \
+  DECLARE_GPU_SPEC_INDEX(T, int32); \
+  DECLARE_GPU_SPEC_INDEX(T, int64)
 
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
 
+#undef DECLARE_GPU_SPEC_INDEX
 #undef DECLARE_GPU_SPEC
 
 }  // namespace functor
 
 // Registration of the GPU implementations.
-#define REGISTER_ONE_HOT_GPU(type)                        \
-  REGISTER_KERNEL_BUILDER(Name("OneHot")                  \
-                              .Device(DEVICE_GPU)         \
-                              .TypeConstraint<type>("T")  \
-                              .HostMemory("depth"),       \
-                          OneHotOp<GPUDevice, type>);
+#define REGISTER_ONE_HOT_GPU_INDEX(type, index_type)            \
+  REGISTER_KERNEL_BUILDER(Name("OneHot")                        \
+                              .Device(DEVICE_GPU)               \
+                              .TypeConstraint<index_type>("TI") \
+                              .TypeConstraint<type>("T")        \
+                              .HostMemory("depth"),             \
+                          OneHotOp<GPUDevice, type, index_type>);
+
+#define REGISTER_ONE_HOT_GPU(type)         \
+  REGISTER_ONE_HOT_GPU_INDEX(type, int32); \
+  REGISTER_ONE_HOT_GPU_INDEX(type, int64)
 
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_ONE_HOT_GPU);
 
+#undef REGISTER_ONE_HOT_GPU_INDEX
+#undef REGISTER_ONE_HOT_GPU
 
 #endif  // GOOGLE_CUDA
@@ -27,11 +27,11 @@ namespace tensorflow {
 
 namespace generator {
 
-template <typename T>
+template <typename T, typename TI>
 class OneGenerator {
  public:
   EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
-  OneGenerator(const TTypes<int64>::ConstMatrix& indices,
+  OneGenerator(const typename TTypes<TI>::ConstMatrix& indices,
                const typename TTypes<T>::ConstScalar& on_value,
                const typename TTypes<T>::ConstScalar& off_value)
       : indices_(indices), on_value_(on_value), off_value_(off_value) {}
@@ -44,7 +44,7 @@ class OneGenerator {
   }
 
  private:
-  const TTypes<int64>::ConstMatrix indices_;
+  const typename TTypes<TI>::ConstMatrix indices_;
   const typename TTypes<T>::ConstScalar on_value_;
   const typename TTypes<T>::ConstScalar off_value_;
 };
@@ -53,14 +53,14 @@ class OneGenerator {
 
 namespace functor {
 
-template <typename Device, typename T>
+template <typename Device, typename T, typename TI>
 struct OneHot {
   EIGEN_ALWAYS_INLINE static void Compute(
-      const Device& d, const TTypes<int64>::ConstMatrix& indices,
+      const Device& d, const typename TTypes<TI>::ConstMatrix& indices,
       const typename TTypes<T>::ConstScalar& on_value,
       const typename TTypes<T>::ConstScalar& off_value,
       typename TTypes<T, 3>::Tensor* output) {
-    generator::OneGenerator<T> generator(indices, on_value, off_value);
+    generator::OneGenerator<T, TI> generator(indices, on_value, off_value);
     output->device(d) = output->generate(generator);
   }
 };
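The functor above fills the output by evaluating a generator at every coordinate: output(p, d, s) is on_value exactly when indices(p, s) == d, and off_value otherwise (out-of-range indices hit nothing). A NumPy model of that semantics, illustrative only — the real code is an Eigen generator expression:

```
import numpy as np

def one_hot_reference(indices, depth, on_value, off_value):
    """NumPy model of the OneGenerator semantics above (illustrative only).

    `indices` is the prefix_dim_size x suffix_dim_size matrix; the result is
    prefix_dim_size x depth x suffix_dim_size, matching the kernel's view.
    """
    prefix, suffix = indices.shape
    out = np.full((prefix, depth, suffix), off_value)
    for p in range(prefix):
        for s in range(suffix):
            d = indices[p, s]
            if 0 <= d < depth:  # out-of-range indices stay off_value
                out[p, d, s] = on_value
    return out

print(one_hot_reference(np.array([[0, 2, -1, 1]]), 3, 1.0, 0.0))
```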
@@ -21,17 +21,24 @@ limitations under the License.
 
 #include "tensorflow/core/kernels/one_hot_op.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/types.h"
 
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
 
-#define DEFINE_GPU_SPEC(T)                       \
-  template class generator::OneGenerator<T>;     \
-  template struct functor::OneHot<GPUDevice, T>;
+#define DEFINE_GPU_SPEC_INDEX(T, TI)                 \
+  template class generator::OneGenerator<T, TI>;     \
+  template struct functor::OneHot<GPUDevice, T, TI>;
+
+#define DEFINE_GPU_SPEC(T)         \
+  DEFINE_GPU_SPEC_INDEX(T, int32); \
+  DEFINE_GPU_SPEC_INDEX(T, int64)
 
 TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
 
+#undef DEFINE_GPU_SPEC_INDEX
 #undef DEFINE_GPU_SPEC
 
 }  // end namespace tensorflow
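These explicit instantiations give the GPU path the same (T, TI) kernel pairs as the CPU registrations. A hedged usage sketch — the device string and GPU availability depend on the local build, and soft placement falls back to CPU otherwise:

```
import tensorflow as tf

# Illustrative only: exercises the GPU kernels instantiated above when a
# GPU is present.
with tf.device("/gpu:0"):
    hot = tf.one_hot([0, 1, 2], depth=3)

with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    print(sess.run(hot))
```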
@@ -1733,13 +1733,14 @@ dimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from
 )doc");
 
 REGISTER_OP("OneHot")
-    .Input("indices: int64")
+    .Input("indices: TI")
     .Input("depth: int32")
     .Input("on_value: T")
     .Input("off_value: T")
     .Attr("axis: int = -1")
     .Output("output: T")
     .Attr("T: type")
+    .Attr("TI: {int32, int64} = DT_INT64")
     .Doc(R"doc(
 Returns a one-hot tensor.
 
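With the new `TI` attr, `indices` may be either int32 or int64, defaulting to DT_INT64 so existing graphs keep working. A minimal sketch exercising both index types, mirroring the new `testIndicesTypes` test added later in this diff:

```
import numpy as np
import tensorflow as tf

# Illustrative: both index dtypes now reach the same OneHot kernel family.
for itype in (np.int32, np.int64):
    indices = np.asarray([0, 2, -1, 1], dtype=itype)
    hot = tf.one_hot(indices, depth=3, on_value=1.0, off_value=0.0)
    with tf.Session() as sess:
        print(itype, sess.run(hot))  # the -1 entry yields an all-off row
```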
@@ -34,7 +34,7 @@ class OneHotTest(tf.test.TestCase):
       ans = tf.one_hot(**inputs)
       if expected_err_re is None:
         tf_ans = ans.eval()
-        self.assertAllClose(tf_ans, truth, atol=1e-10)
+        self.assertAllEqual(tf_ans, truth)
         self.assertEqual(tf_ans.shape, ans.get_shape())
       else:
         with self.assertRaisesOpError(expected_err_re):
@@ -77,7 +77,7 @@ class OneHotTest(tf.test.TestCase):
         truth=truth.T)  # Output is transpose version in this case
 
   def _testDefaultBasic(self, dtype):
-    indices = np.asarray([0, 2, -1, 1], dtype=dtype)
+    indices = np.asarray([0, 2, -1, 1], dtype=np.int64)
     depth = 3
 
     truth = np.asarray(
@@ -89,18 +89,16 @@ class OneHotTest(tf.test.TestCase):
 
     # axis == -1
     self._testBothOneHot(
-        indices=indices,
-        depth=depth,
-        dtype=dtype,
-        truth=truth)
+        indices=indices,
+        depth=depth,
+        truth=truth)
 
     # axis == 0
     self._testBothOneHot(
-        indices=indices,
-        depth=depth,
-        axis=0,
-        dtype=dtype,
-        truth=truth.T)  # Output is transpose version in this case
+        indices=indices,
+        depth=depth,
+        axis=0,
+        truth=truth.T)  # Output is transpose version in this case
 
   def testFloatBasic(self):
     self._testBasic(np.float32)
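These calls drop `dtype=dtype` because the Python wrapper (changed later in this diff) now supplies defaults. A quick sketch of the default behavior, assuming a build with this change:

```
import tensorflow as tf

# Illustrative: with on_value, off_value, and dtype all omitted, the op
# defaults to on_value=1, off_value=0, dtype=tf.float32.
hot = tf.one_hot([0, 2, -1, 1], depth=3)
print(hot.dtype)          # <dtype: 'float32'>
with tf.Session() as sess:
    print(sess.run(hot))  # rows are one-hot; the -1 row is all zeros
```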
@@ -163,7 +161,7 @@ class OneHotTest(tf.test.TestCase):
   def _testDefaultValuesBatch(self, dtype):
     indices = np.asarray([[0, 2, -1, 1],
                           [1, 0, 1, -1]],
-                         dtype=dtype)
+                         dtype=np.int64)
     depth = 3
 
     truth = np.asarray(
@@ -192,10 +190,10 @@ class OneHotTest(tf.test.TestCase):
         dtype=dtype,
         truth=[truth[0].T, truth[1].T])  # Do not transpose the batch
 
-  def _testTypeBatch(self, dtype):
+  def _testValueTypeBatch(self, dtype):
     indices = np.asarray([[0, 2, -1, 1],
                           [1, 0, 1, -1]],
-                         dtype=dtype)
+                         dtype=np.int64)
     depth = 3
 
     on_value = np.asarray(1.0, dtype=dtype)
@@ -218,6 +216,7 @@ class OneHotTest(tf.test.TestCase):
         on_value=on_value,
         off_value=off_value,
         depth=depth,
+        dtype=dtype,
         truth=truth)
 
     # axis == 1
@@ -227,32 +226,33 @@ class OneHotTest(tf.test.TestCase):
         off_value=off_value,
         depth=depth,
         axis=1,
+        dtype=dtype,
         truth=[truth[0].T, truth[1].T])  # Do not transpose the batch
 
   def testFloatBatch(self):
     self._testBatch(np.float32)
     self._testDefaultValuesBatch(np.float32)
-    self._testTypeBatch(np.float32)
+    self._testValueTypeBatch(np.float32)
 
   def testDoubleBatch(self):
     self._testBatch(np.float64)
     self._testDefaultValuesBatch(np.float64)
-    self._testTypeBatch(np.float64)
+    self._testValueTypeBatch(np.float64)
 
   def testInt32Batch(self):
     self._testBatch(np.int32)
     self._testDefaultValuesBatch(np.int32)
-    self._testTypeBatch(np.int32)
+    self._testValueTypeBatch(np.int32)
 
   def testInt64Batch(self):
     self._testBatch(np.int64)
     self._testDefaultValuesBatch(np.int64)
-    self._testTypeBatch(np.int64)
+    self._testValueTypeBatch(np.int64)
 
   def testComplexBatch(self):
     self._testBatch(np.complex64)
-    self._testDefaultValuesBatch(np.complex64)
-    self._testTypeBatch(np.complex64)
+    # self._testDefaultValuesBatch(np.complex64)
+    self._testValueTypeBatch(np.complex64)
 
   def testSimpleCases(self):
     indices = [0,1,2]
@@ -284,17 +284,6 @@ class OneHotTest(tf.test.TestCase):
     self._testBothOneHot(indices=indices, depth=depth, on_value=1,
                          off_value=-1, truth=truth)
 
-  def testStringDtypeError(self):
-    indices = [0,1,2]
-    depth = 3
-    truth = np.asarray(
-        [[1.0, 0.0, 0.0],
-         [0.0, 1.0, 0.0],
-         [0.0, 0.0, 1.0]])
-    self._testBothOneHot(indices=indices, depth=depth, on_value=1,
-                         off_value=-1, dtype=tf.string, raises=TypeError,
-                         truth=truth)
-
   def testSingleValueGiven(self):
     # Only on_value provided
     indices = [0,1,2]
@@ -317,5 +306,111 @@ class OneHotTest(tf.test.TestCase):
     self._testBothOneHot(indices=indices, depth=depth,
                          off_value=0.0, truth=truth)
 
+  def testString(self):
+    indices = [0,1,2]
+    depth = 3
+    truth = np.asarray(
+        [[b"1.0", b"0.0", b"0.0"],
+         [b"0.0", b"1.0", b"0.0"],
+         [b"0.0", b"0.0", b"1.0"]])
+    on_value = np.asarray(b"1.0")
+    off_value = np.asarray(b"0.0")
+
+    self._testBothOneHot(indices=indices, depth=depth, on_value=on_value,
+                         off_value=off_value, dtype=tf.string, truth=truth)
+
+    on_value = tf.constant(b"1.0")
+    off_value = tf.constant(b"0.0")
+    self._testBothOneHot(indices=indices, depth=depth, on_value=on_value,
+                         off_value=off_value, dtype=tf.string, truth=truth)
+
+    on_value = b"1.0"
+    off_value = b"0.0"
+    self._testBothOneHot(indices=indices, depth=depth, on_value=on_value,
+                         off_value=off_value, dtype=tf.string, truth=truth)
+
+  def testIndicesTypes(self):
+    tf_types = [tf.int32, tf.int64]
+    np_types = [np.int32, np.int64]
+    for itype in tf_types + np_types:
+      if itype in tf_types:
+        indices = tf.constant([[0, 2, -1, 1],
+                               [1, 0, 1, -1]],
+                              dtype=itype)
+      elif itype in np_types:
+        indices = np.asarray([[0, 2, -1, 1],
+                              [1, 0, 1, -1]],
+                             dtype=itype)
+      depth = 3
+
+      on_value = np.asarray(1.0, dtype=np.float32)
+      off_value = np.asarray(-1.0, dtype=np.float32)
+
+      truth = np.asarray(
+          [[[1.0, -1.0, -1.0],
+            [-1.0, -1.0, 1.0],
+            [-1.0, -1.0, -1.0],
+            [-1.0, 1.0, -1.0]],
+           [[-1.0, 1.0, -1.0],
+            [1.0, -1.0, -1.0],
+            [-1.0, 1.0, -1.0],
+            [-1.0, -1.0, -1.0]]],
+          dtype=np.float32)
+
+      # axis == -1
+      self._testBothOneHot(
+          indices=indices,
+          on_value=on_value,
+          off_value=off_value,
+          depth=depth,
+          truth=truth)
+
+      # axis == 1
+      self._testBothOneHot(
+          indices=indices,
+          on_value=on_value,
+          off_value=off_value,
+          depth=depth,
+          axis=1,
+          truth=[truth[0].T, truth[1].T])  # Do not transpose the batch
+
+  def testOnOffMismatchTypeError(self):
+    indices = [0, 1, 2]
+    depth = 3
+    on_value = np.asarray(1.0, np.float64)
+    off_value = np.asarray(0.0, np.float32)
+
+    self._testBothOneHot(
+        indices=indices,
+        depth=depth,
+        on_value=on_value,
+        off_value=off_value,
+        truth=None,
+        raises=TypeError)
+
+  def testDtypeMismatchTypeError(self):
+    indices = [0, 1, 2]
+    depth = 3
+    on_value = np.asarray(1.0, np.float32)
+    off_value = np.asarray(0.0, np.float32)
+    dtype = np.int32
+
+    self._testBothOneHot(
+        indices=indices,
+        depth=depth,
+        on_value=on_value,
+        dtype=dtype,
+        truth=None,
+        raises=TypeError)
+
+    self._testBothOneHot(
+        indices=indices,
+        depth=depth,
+        on_value=off_value,
+        dtype=dtype,
+        truth=None,
+        raises=TypeError)
+
 
 if __name__ == "__main__":
   tf.test.main()
@@ -1912,14 +1912,21 @@ def _DepthToSpaceShape(op):
             [input_shape[0], height, width, new_depth])]
 
 
-def one_hot(indices, depth, on_value=1, off_value=0,
-            axis=None, dtype=dtypes.float32, name=None):
+def one_hot(indices, depth, on_value=None, off_value=None,
+            axis=None, dtype=None, name=None):
   """Returns a one-hot tensor.
 
   The locations represented by indices in `indices` take value `on_value`,
-  while all other locations take value `off_value`. By default, `on_value` is 1,
-  and `off_value` is 0. The type of the output tensor is specified by `dtype`,
-  which defaults to `tf.float32`.
+  while all other locations take value `off_value`.
+
+  `on_value` and `off_value` must have matching data types. If `dtype` is also
+  provided, they must be the same data type as specified by `dtype`.
+
+  If `on_value` is not provided, it will default to the value `1` with type
+  `dtype`.
+
+  If `off_value` is not provided, it will default to the value `0` with type
+  `dtype`.
 
   If the input `indices` is rank `N`, the output will have rank `N+1`. The
   new axis is created at dimension `axis` (default: the new axis is appended
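A short sketch of the matching-dtype rule described above, mirroring the new `testOnOffMismatchTypeError` test (illustrative; the exact error text may differ):

```
import numpy as np
import tensorflow as tf

# Illustrative: on_value and off_value must agree in dtype; float64 vs
# float32 now raises TypeError instead of silently converting.
try:
    tf.one_hot([0, 1, 2], depth=3,
               on_value=np.float64(1.0), off_value=np.float32(0.0))
except TypeError as e:
    print(e)  # dtype of on_value does not match dtype of off_value
```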
@@ -1941,6 +1948,13 @@ def one_hot(indices, depth, on_value=1, off_value=0,
     depth x batch x features if axis == 0
   ```
 
+  If `dtype` is not provided, it will be inferred from the data type of
+  `on_value` or `off_value`, if one or both are passed in. If none of
+  `on_value`, `off_value`, or `dtype` are provided, `dtype` will default to
+  the value `tf.float32`.
+
+  Note: If a non-numeric output type is desired (`tf.string`, `tf.bool`, etc.),
+  both `on_value` and `off_value` _must_ be provided to `one_hot`.
+
   Examples
   =========
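One consequence of the note above: non-numeric outputs work once both values are given, since `dtype` is then inferred from them. A hedged sketch, mirroring the new `testString` test:

```
import tensorflow as tf

# Illustrative: for non-numeric outputs such as tf.string, both on_value
# and off_value must be supplied; dtype is then inferred from them.
hot = tf.one_hot([0, 1, 2], depth=3, on_value=b"yes", off_value=b"no")
print(hot.dtype)  # <dtype: 'string'>
with tf.Session() as sess:
    print(sess.run(hot))
```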
@@ -1988,6 +2002,22 @@ def one_hot(indices, depth, on_value=1, off_value=0,
     ]
   ```
 
+  Using default values for `on_value` and `off_value`:
+
+  ```
+    indices = [0, 1, 2]
+    depth = 3
+  ```
+
+  The output will be
+
+  ```
+    output =
+    [[1., 0., 0.],
+     [0., 1., 0.],
+     [0., 0., 1.]]
+  ```
+
   Args:
     indices: A `Tensor` of indices.
     depth: A scalar defining the depth of the one hot dimension.
@@ -2002,20 +2032,50 @@ def one_hot(indices, depth, on_value=1, off_value=0,
     output: The one-hot tensor.
 
   Raises:
-    TypeError: If dtype is `tf.string`
+    TypeError: If dtype of either `on_value` or `off_value` doesn't match `dtype`
+    TypeError: If dtype of `on_value` and `off_value` don't match one another
   """
-  # Check for bad dtype specification
-  if dtype == dtypes.string:
-    raise TypeError("dtype must be a numeric type")
-
   with ops.op_scope([indices, depth, on_value, off_value,
                      axis, dtype], name, "one_hot") as name:
-    on_value = ops.convert_to_tensor(on_value, dtype=dtype, name="on_value")
-    off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value")
-    indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices")
-    depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth")
-    return gen_array_ops._one_hot(indices, depth, on_value, off_value, axis,
-                                  name)
+    on_exists = on_value is not None
+    off_exists = off_value is not None
+
+    on_dtype = ops.convert_to_tensor(on_value).dtype.base_dtype if on_exists \
+        else None
+    off_dtype = ops.convert_to_tensor(off_value).dtype.base_dtype if off_exists \
+        else None
+
+    if on_exists or off_exists:
+      if dtype is not None:
+        # Ensure provided on_value and/or off_value match dtype
+        if on_exists and on_dtype != dtype:
+          raise TypeError("dtype {0} of on_value does not match "
+                          "dtype parameter {1}".format(on_dtype, dtype))
+        if off_exists and off_dtype != dtype:
+          raise TypeError("dtype {0} of off_value does not match "
+                          "dtype parameter {1}".format(off_dtype, dtype))
+      else:
+        # dtype not provided: automatically assign it
+        dtype = on_dtype if on_exists else off_dtype
+    elif dtype is None:
+      # None of on_value, off_value, or dtype provided. Default dtype to float32
+      dtype = dtypes.float32
+
+    if not on_exists:
+      # on_value not provided: assign to value 1 of type dtype
+      on_value = ops.convert_to_tensor(1, dtype, name="on_value")
+      on_dtype = dtype
+    if not off_exists:
+      # off_value not provided: assign to value 0 of type dtype
+      off_value = ops.convert_to_tensor(0, dtype, name="off_value")
+      off_dtype = dtype
+
+    if on_dtype != off_dtype:
+      raise TypeError("dtype {0} of on_value does not match "
+                      "dtype {1} of off_value".format(on_dtype, off_dtype))
+
+    return gen_array_ops._one_hot(indices, depth, on_value,
+                                  off_value, axis, name)
 
 
 @ops.RegisterShape("OneHot")
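For reference, here is the dtype resolution implemented above condensed into a standalone function. Illustrative only — the real code resolves dtypes by calling `ops.convert_to_tensor` on the actual values:

```
def infer_one_hot_dtype(on_dtype, off_dtype, dtype):
    """Condensed restatement of the dtype rules implemented above.

    Each argument is a dtype or None; returns the output dtype or raises.
    Illustrative only -- the real implementation works on tensors.
    """
    if on_dtype is not None or off_dtype is not None:
        if dtype is not None:
            if on_dtype is not None and on_dtype != dtype:
                raise TypeError("on_value dtype does not match dtype")
            if off_dtype is not None and off_dtype != dtype:
                raise TypeError("off_value dtype does not match dtype")
        else:
            # dtype not provided: inferred from whichever value exists
            dtype = on_dtype if on_dtype is not None else off_dtype
    elif dtype is None:
        dtype = "float32"  # stand-in for dtypes.float32
    # defaults fill in the missing value with the resolved dtype
    on_dtype = on_dtype or dtype
    off_dtype = off_dtype or dtype
    if on_dtype != off_dtype:
        raise TypeError("on_value dtype does not match off_value dtype")
    return dtype

assert infer_one_hot_dtype(None, None, None) == "float32"
assert infer_one_hot_dtype("int32", None, None) == "int32"
```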