Adds tests for unknown shape in bitcast
Change: 122089210
This commit is contained in:
parent
3451a5903f
commit
136a89a7df
@ -46,7 +46,8 @@ class BitcastOp : public OpKernel {
|
||||
OP_REQUIRES(context, in_size_ >= out_size_ ||
|
||||
(input_tensor.dims() > 0 &&
|
||||
input_tensor.dim_size(input_tensor.dims() - 1) ==
|
||||
out_size_ / in_size_),
|
||||
out_size_ / in_size_) ||
|
||||
input_tensor.dim_size(input_tensor.dims()) == -1,
|
||||
errors::InvalidArgument(
|
||||
"Cannot bitcast from ", DataTypeString(input_data_type_),
|
||||
" to ", DataTypeString(output_data_type_), ": shape ",
|
||||
|
@ -68,6 +68,11 @@ class BitcastTest(tf.test.TestCase):
|
||||
shape = [4]
|
||||
self._testBitcast(x, datatype, shape)
|
||||
|
||||
def testUnknown(self):
|
||||
x = tf.placeholder(tf.float32)
|
||||
datatype = tf.int8
|
||||
tf.bitcast(x, datatype, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
@ -1126,17 +1126,19 @@ def _SqueezeShape(op):
|
||||
def _BitcastShape(op):
|
||||
"""Shape function for Bitcast op."""
|
||||
input_shape = op.inputs[0].get_shape()
|
||||
if input_shape == tensor_shape.unknown_shape():
|
||||
return [tensor_shape.unknown_shape()]
|
||||
input_type = op.inputs[0].dtype
|
||||
size_of_input = input_type.size
|
||||
output = dtypes.as_dtype(op.get_attr("type"))
|
||||
size_of_output = output.size
|
||||
if size_of_input == size_of_output:
|
||||
return [tensor_shape.TensorShape(input_shape)]
|
||||
return [input_shape]
|
||||
else:
|
||||
if size_of_output > size_of_input:
|
||||
new_shape = input_shape.as_list()
|
||||
new_shape = input_shape.with_rank_at_least(1).as_list()
|
||||
last_val = new_shape[-1]
|
||||
if last_val == (size_of_output // size_of_input):
|
||||
if last_val is None or last_val == (size_of_output // size_of_input):
|
||||
new_shape = new_shape[:-1]
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -1918,8 +1920,8 @@ def one_hot(indices, depth, on_value=1, off_value=0,
|
||||
off_value = ops.convert_to_tensor(off_value, dtype=dtype, name="off_value")
|
||||
indices = ops.convert_to_tensor(indices, dtype=dtypes.int64, name="indices")
|
||||
depth = ops.convert_to_tensor(depth, dtype=dtypes.int32, name="depth")
|
||||
return gen_array_ops._one_hot(indices, depth, on_value,
|
||||
off_value, axis, name)
|
||||
return gen_array_ops._one_hot(indices, depth, on_value, off_value, axis,
|
||||
name)
|
||||
|
||||
|
||||
@ops.RegisterShape("OneHot")
|
||||
|
Loading…
Reference in New Issue
Block a user