Adds tests for unknown shape in bitcast
Change: 122089210
This commit is contained in:
parent
3451a5903f
commit
136a89a7df
@ -46,7 +46,8 @@ class BitcastOp : public OpKernel {
|
|||||||
OP_REQUIRES(context, in_size_ >= out_size_ ||
|
OP_REQUIRES(context, in_size_ >= out_size_ ||
|
||||||
(input_tensor.dims() > 0 &&
|
(input_tensor.dims() > 0 &&
|
||||||
input_tensor.dim_size(input_tensor.dims() - 1) ==
|
input_tensor.dim_size(input_tensor.dims() - 1) ==
|
||||||
out_size_ / in_size_),
|
out_size_ / in_size_) ||
|
||||||
|
input_tensor.dim_size(input_tensor.dims()) == -1,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Cannot bitcast from ", DataTypeString(input_data_type_),
|
"Cannot bitcast from ", DataTypeString(input_data_type_),
|
||||||
" to ", DataTypeString(output_data_type_), ": shape ",
|
" to ", DataTypeString(output_data_type_), ": shape ",
|
||||||
|
@ -68,6 +68,11 @@ class BitcastTest(tf.test.TestCase):
|
|||||||
shape = [4]
|
shape = [4]
|
||||||
self._testBitcast(x, datatype, shape)
|
self._testBitcast(x, datatype, shape)
|
||||||
|
|
||||||
|
def testUnknown(self):
|
||||||
|
x = tf.placeholder(tf.float32)
|
||||||
|
datatype = tf.int8
|
||||||
|
tf.bitcast(x, datatype, None)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
tf.test.main()
|
tf.test.main()
|
||||||
|
@ -1126,17 +1126,19 @@ def _SqueezeShape(op):
|
|||||||
def _BitcastShape(op):
|
def _BitcastShape(op):
|
||||||
"""Shape function for Bitcast op."""
|
"""Shape function for Bitcast op."""
|
||||||
input_shape = op.inputs[0].get_shape()
|
input_shape = op.inputs[0].get_shape()
|
||||||
|
if input_shape == tensor_shape.unknown_shape():
|
||||||
|
return [tensor_shape.unknown_shape()]
|
||||||
input_type = op.inputs[0].dtype
|
input_type = op.inputs[0].dtype
|
||||||
size_of_input = input_type.size
|
size_of_input = input_type.size
|
||||||
output = dtypes.as_dtype(op.get_attr("type"))
|
output = dtypes.as_dtype(op.get_attr("type"))
|
||||||
size_of_output = output.size
|
size_of_output = output.size
|
||||||
if size_of_input == size_of_output:
|
if size_of_input == size_of_output:
|
||||||
return [tensor_shape.TensorShape(input_shape)]
|
return [input_shape]
|
||||||
else:
|
else:
|
||||||
if size_of_output > size_of_input:
|
if size_of_output > size_of_input:
|
||||||
new_shape = input_shape.as_list()
|
new_shape = input_shape.with_rank_at_least(1).as_list()
|
||||||
last_val = new_shape[-1]
|
last_val = new_shape[-1]
|
||||||
if last_val == (size_of_output // size_of_input):
|
if last_val is None or last_val == (size_of_output // size_of_input):
|
||||||
new_shape = new_shape[:-1]
|
new_shape = new_shape[:-1]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1822,7 +1824,7 @@ def one_hot(indices, depth, on_value=1, off_value=0,
|
|||||||
|
|
||||||
The locations represented by indices in `indices` take value `on_value`,
|
The locations represented by indices in `indices` take value `on_value`,
|
||||||
while all other locations take value `off_value`. By default, `on_value` is 1,
|
while all other locations take value `off_value`. By default, `on_value` is 1,
|
||||||
and `off_value` is 0. The type of the output tensor is specified by `dtype`,
|
and `off_value` is 0. The type of the output tensor is specified by `dtype`,
|
||||||
which defaults to `tf.float32`.
|
which defaults to `tf.float32`.
|
||||||
|
|
||||||
If the input `indices` is rank `N`, the output will have rank `N+1`. The
|
If the input `indices` is rank `N`, the output will have rank `N+1`. The
|
||||||
@ -1918,8 +1920,8 @@ def one_hot(indices, depth, on_value=1, off_value=0,
|
|||||||
off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value")
|
off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value")
|
||||||
indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices")
|
indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices")
|
||||||
depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth")
|
depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth")
|
||||||
return gen_array_ops._one_hot(indices, depth, on_value,
|
return gen_array_ops._one_hot(indices, depth, on_value, off_value, axis,
|
||||||
off_value, axis, name)
|
name)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("OneHot")
|
@ops.RegisterShape("OneHot")
|
||||||
|
Loading…
Reference in New Issue
Block a user