Adds tests for unknown shape in bitcast

Change: 122089210
This commit is contained in:
Olivia Nordquist 2016-05-11 12:49:29 -08:00 committed by TensorFlower Gardener
parent 3451a5903f
commit 136a89a7df
3 changed files with 15 additions and 7 deletions

View File

@ -46,7 +46,8 @@ class BitcastOp : public OpKernel {
OP_REQUIRES(context, in_size_ >= out_size_ || OP_REQUIRES(context, in_size_ >= out_size_ ||
(input_tensor.dims() > 0 && (input_tensor.dims() > 0 &&
input_tensor.dim_size(input_tensor.dims() - 1) == input_tensor.dim_size(input_tensor.dims() - 1) ==
out_size_ / in_size_), out_size_ / in_size_) ||
input_tensor.dim_size(input_tensor.dims()) == -1,
errors::InvalidArgument( errors::InvalidArgument(
"Cannot bitcast from ", DataTypeString(input_data_type_), "Cannot bitcast from ", DataTypeString(input_data_type_),
" to ", DataTypeString(output_data_type_), ": shape ", " to ", DataTypeString(output_data_type_), ": shape ",

View File

@ -68,6 +68,11 @@ class BitcastTest(tf.test.TestCase):
shape = [4] shape = [4]
self._testBitcast(x, datatype, shape) self._testBitcast(x, datatype, shape)
def testUnknown(self):
x = tf.placeholder(tf.float32)
datatype = tf.int8
tf.bitcast(x, datatype, None)
if __name__ == "__main__": if __name__ == "__main__":
tf.test.main() tf.test.main()

View File

@ -1126,17 +1126,19 @@ def _SqueezeShape(op):
def _BitcastShape(op): def _BitcastShape(op):
"""Shape function for Bitcast op.""" """Shape function for Bitcast op."""
input_shape = op.inputs[0].get_shape() input_shape = op.inputs[0].get_shape()
if input_shape == tensor_shape.unknown_shape():
return [tensor_shape.unknown_shape()]
input_type = op.inputs[0].dtype input_type = op.inputs[0].dtype
size_of_input = input_type.size size_of_input = input_type.size
output = dtypes.as_dtype(op.get_attr("type")) output = dtypes.as_dtype(op.get_attr("type"))
size_of_output = output.size size_of_output = output.size
if size_of_input == size_of_output: if size_of_input == size_of_output:
return [tensor_shape.TensorShape(input_shape)] return [input_shape]
else: else:
if size_of_output > size_of_input: if size_of_output > size_of_input:
new_shape = input_shape.as_list() new_shape = input_shape.with_rank_at_least(1).as_list()
last_val = new_shape[-1] last_val = new_shape[-1]
if last_val == (size_of_output // size_of_input): if last_val is None or last_val == (size_of_output // size_of_input):
new_shape = new_shape[:-1] new_shape = new_shape[:-1]
else: else:
raise ValueError( raise ValueError(
@ -1822,7 +1824,7 @@ def one_hot(indices, depth, on_value=1, off_value=0,
The locations represented by indices in `indices` take value `on_value`, The locations represented by indices in `indices` take value `on_value`,
while all other locations take value `off_value`. By default, `on_value` is 1, while all other locations take value `off_value`. By default, `on_value` is 1,
and `off_value` is 0. The type of the output tensor is specified by `dtype`, and `off_value` is 0. The type of the output tensor is specified by `dtype`,
which defaults to `tf.float32`. which defaults to `tf.float32`.
If the input `indices` is rank `N`, the output will have rank `N+1`. The If the input `indices` is rank `N`, the output will have rank `N+1`. The
@ -1918,8 +1920,8 @@ def one_hot(indices, depth, on_value=1, off_value=0,
off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value") off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value")
indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices") indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices")
depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth") depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth")
return gen_array_ops._one_hot(indices, depth, on_value, return gen_array_ops._one_hot(indices, depth, on_value, off_value, axis,
off_value, axis, name) name)
@ops.RegisterShape("OneHot") @ops.RegisterShape("OneHot")