Switch several ops in array_ops.py to use C++ shape functions.
Change the C++ shape function for ExpandDims to be more permissive: 'dim' may now be any tensor with a single element. ExpandDims itself is not yet converted to the C++ function because a separate issue must be fixed first (later change).
Change the C++ shape functions for SpaceToBatch and BatchToSpace to output rank-4 unknown shapes instead of fully unknown shapes.
Change: 132578764
Parent: 71e3186fd3
Commit: 9205b55c6b
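The permissive 'dim' handling boils down to an element-count check rather than a rank check. A minimal sketch (ours, not part of the diff; CheckExpandDimsDim is a hypothetical helper, the types are TensorFlow's usual Status/Tensor/errors):

// Sketch only: the rank-0 requirement is gone; any tensor carrying exactly
// one element is accepted as 'dim' -- a scalar, a shape-[1] vector, even a
// shape-[1,1] matrix -- while e.g. a shape-[2] tensor is rejected.
Status CheckExpandDimsDim(const Tensor* dim_t) {
  if (dim_t != nullptr && dim_t->NumElements() != 1) {
    return errors::InvalidArgument(
        "'dim' input must be a tensor with a single value");
  }
  return Status::OK();  // nullptr means the value is not known yet: defer.
}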
@@ -270,14 +270,12 @@ class InferenceContext {
  // Returns in <*out> a sub-shape of <s> with dimensions [start:].
  // <start> can be negative to index from the end of the shape. If <start> >
  // rank of <s>, then an empty subshape is returned.
  // Returns an error if the rank of <s> is < <start>.
  Status Subshape(ShapeHandle s, int64 start,
                  ShapeHandle* out) TF_MUST_USE_RESULT;

  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end].
  // <start> and <end> can be negative, to index from the end of the shape.
  // <start> and <end> are set to the rank of <s> if > rank of <s>.
  // Returns an error if the rank of <s> is insufficient.
  Status Subshape(ShapeHandle s, int64 start, int64 end,
                  ShapeHandle* out) TF_MUST_USE_RESULT;
@@ -2281,10 +2281,12 @@ REGISTER_OP("ExpandDims")
     .Attr("Tdim: {int32, int64} = DT_INT32")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle input = c->input(0);
-      ShapeHandle expand_dim;
-      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &expand_dim));

       const Tensor* dim_t = c->input_tensor(1);
+      if (dim_t != nullptr && dim_t->NumElements() != 1) {
+        return errors::InvalidArgument(
+            "'dim' input must be a tensor with a single value");
+      }
       if (dim_t == nullptr || !c->RankKnown(input)) {
         c->set_output(0, c->UnknownShape());
         return Status::OK();
@@ -2516,7 +2518,8 @@ REGISTER_OP("SpaceToBatch")
       DimensionHandle pad1_dim = c->Dim(paddings, 1);

       if (!c->ValueKnown(pad0_dim) || !c->ValueKnown(pad1_dim)) {
-        return shape_inference::UnknownShape(c);
+        c->set_output(0, c->UnknownShapeOfRank(4));
+        return Status::OK();
       }

       int64 pad0 = c->Value(pad0_dim);
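The point of UnknownShapeOfRank(4) over UnknownShape(): consumers of the output still learn the rank, so positional dimension merging and rank checks keep working. A sketch under the same InferenceContext API (ours, not from the diff):

// "?" carries no information; "[?,?,?,?]" pins the rank at 4.
ShapeHandle rank4 = c->UnknownShapeOfRank(4);
ShapeHandle merged;
// Merging with a concrete shape now matches dimensions positionally and
// fails fast on a rank mismatch instead of silently passing through.
TF_RETURN_IF_ERROR(c->Merge(rank4, c->input(0), &merged));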
@@ -2694,7 +2697,8 @@ REGISTER_OP("BatchToSpace")
       DimensionHandle crops1_dim = c->Dim(crops, 1);

       if (!c->ValueKnown(crops0_dim) || !c->ValueKnown(crops1_dim)) {
-        return shape_inference::UnknownShape(c);
+        c->set_output(0, c->UnknownShapeOfRank(4));
+        return Status::OK();
       }

       int64 crops0 = c->Value(crops0_dim);
@@ -385,7 +385,6 @@ TEST(ArrayOpsTest, ExpandDims_ShapeFn) {

   // With unknown dim tensor value, output is unknown.
   INFER_OK(op, "?;?", "?");
-  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1]");
   Tensor dim_t;
   op.input_tensors[1] = &dim_t;

@@ -399,11 +398,21 @@ TEST(ArrayOpsTest, ExpandDims_ShapeFn) {
     dim_t = test::AsScalar<int32>(idx);
     INFER_OK(op, "?;?", "?");
     INFER_OK(op, "[5,?,7];?", "[d0_0,1,d0_1,d0_2]");
+
+    // Repeat with int64.
+    dim_t = test::AsScalar<int64>(idx);
+    INFER_OK(op, "?;?", "?");
+    INFER_OK(op, "[5,?,7];?", "[d0_0,1,d0_1,d0_2]");
   }
   for (int32 idx : {2, -2}) {
     dim_t = test::AsScalar<int32>(idx);
     INFER_OK(op, "?;?", "?");
     INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,1,d0_2]");
+
+    // Repeat with int64.
+    dim_t = test::AsScalar<int64>(idx);
+    INFER_OK(op, "?;?", "?");
+    INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,1,d0_2]");
   }

   for (int32 idx : {3, -1}) {
@@ -411,7 +420,26 @@ TEST(ArrayOpsTest, ExpandDims_ShapeFn) {
     dim_t = test::AsScalar<int32>(idx);
     INFER_OK(op, "?;?", "?");
     INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,d0_2,1]");
+
+    // Repeat with int64.
+    dim_t = test::AsScalar<int64>(idx);
+    INFER_OK(op, "?;?", "?");
+    INFER_OK(op, "[5,?,7];?", "[d0_0,d0_1,d0_2,1]");
   }
+
+  // Expand using an input vector tensor.
+  std::vector<int32> dims;
+  dims.push_back(0);
+  dim_t = test::AsTensor<int32>(dims);
+  INFER_OK(op, "?;?", "?");
+  INFER_OK(op, "[5,?,7];?", "[1,d0_0,d0_1,d0_2]");
+
+  // Expand using too many input elements.
+  dims.push_back(1);
+  dim_t = test::AsTensor<int32>(dims);
+  INFER_ERROR("'dim' input must be a tensor with a single", op, "?;?");
+  INFER_ERROR("'dim' input must be a tensor with a single", op, "[5,6,7];?");
+
   // Examples from ExpandDims doc.
   dim_t = test::AsScalar<int32>(0);
   INFER_OK(op, "[2];[]", "[1,d0_0]");
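For readers new to these macros from the shape-inference test utilities: INFER_OK(op, ins, expected) runs the registered shape function on a ';'-separated list of input shapes, "?" denotes a fully unknown shape, and "d0_1" in the expected string means "the same dimension handle as input 0, dimension 1". One assertion from above, annotated (our comment):

// With dim == 1, a size-1 dimension is spliced into input 0's shape and the
// original dimensions pass through intact: [5,?,7] -> [5,1,?,7].
INFER_OK(op, "[5,?,7];?", "[d0_0,1,d0_1,d0_2]");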
@@ -928,8 +956,8 @@ TEST(ArrayOpsTest, SpaceToBatch_ShapeFn) {
   // Paddings not known, but batch size can be computed.
   INFER_OK(op, "[1,10,10,3];[2,2]", "[4,?,?,d0_3]");

-  // Unknown paddings means unknown shape
-  INFER_OK(op, "[1,10,10,3];?", "?");
+  // Unknown paddings means unknown shape of rank 4.
+  INFER_OK(op, "[1,10,10,3];?", "[?,?,?,?]");

   // Paddings not correct shape
   INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[1,10,10,3];[4]");
@@ -971,7 +999,7 @@ TEST(ArrayOpsTest, BatchToSpace_ShapeFn) {
                "[5,8,8,3];[2,2]");

   // Unknown croppings means unknown shape
-  INFER_OK(op, "[4,8,8,3];?", "?");
+  INFER_OK(op, "[4,8,8,3];?", "[?,?,?,?]");

   // croppings not correct shape
   INFER_ERROR("Shape must be rank 2 but is rank 1", op, "[4,8,8,3];[4]");
@@ -646,7 +646,7 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
     if str(result) != str(python_result):
       raise ValueError(
           ("Python vs CPP shape mismatch. "
-           "python: %s vs CPP: %s on node %s "
+           "CPP: %s vs python: %s on node %s "
            "with input shapes %s") % (
               str(result), str(python_result), str(op.node_def),
               ",".join([str(i.get_shape()) for i in op.inputs])))
@@ -166,7 +166,7 @@ class SpaceToBatchErrorHandlingTest(tf.test.TestCase):
     x_np = [[[[1], [2]], [[3], [4]]]]
     paddings = np.zeros((2, 2), dtype=np.int32)
     block_size = 10
-    with self.assertRaises(IndexError):
+    with self.assertRaises(ValueError):
       out_tf = tf.space_to_batch(x_np, paddings, block_size)
       out_tf.eval()

@@ -175,7 +175,7 @@ class SpaceToBatchErrorHandlingTest(tf.test.TestCase):
     x_np = [[[[1], [2], [3]], [[3], [4], [7]]]]
     paddings = np.zeros((2, 2), dtype=np.int32)
     block_size = 3
-    with self.assertRaises(IndexError):
+    with self.assertRaises(ValueError):
       _ = tf.space_to_batch(x_np, paddings, block_size)

   def testBlockSizeNotDivisibleHeight(self):
@@ -183,7 +183,7 @@ class SpaceToBatchErrorHandlingTest(tf.test.TestCase):
     x_np = [[[[1], [2]], [[3], [4]], [[5], [6]]]]
     paddings = np.zeros((2, 2), dtype=np.int32)
     block_size = 3
-    with self.assertRaises(IndexError):
+    with self.assertRaises(ValueError):
       _ = tf.space_to_batch(x_np, paddings, block_size)

   def testBlockSizeNotDivisibleBoth(self):
@@ -191,7 +191,7 @@ class SpaceToBatchErrorHandlingTest(tf.test.TestCase):
     x_np = [[[[1], [2]], [[3], [4]]]]
     paddings = np.zeros((2, 2), dtype=np.int32)
     block_size = 3
-    with self.assertRaises(IndexError):
+    with self.assertRaises(ValueError):
       _ = tf.space_to_batch(x_np, paddings, block_size)

   def testUnknownShape(self):
@@ -226,7 +226,7 @@ class TransposeTest(tf.test.TestCase):
     self._testError(np.arange(0., 2 ** 11).reshape([2] * 11),
                     np.arange(11),
                     "not implemented")
-    with self.assertRaises(IndexError):
+    with self.assertRaises(ValueError):
       tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [0, 1, 3])
     self._testError(np.arange(0., 30).reshape([2, 3, 5]),
                     [0, 1, 1],
@@ -807,39 +807,7 @@ ops.RegisterShape("Unpack")(common_shapes.call_cpp_shape_fn)

 @ops.RegisterShape("Concat")
 def _ConcatShape(op):
-  concat_dim = tensor_util.constant_value(op.inputs[0])
-  if concat_dim is None:
-    # Return an unknown shape with the same rank as the inputs, or an
-    # unknown rank if no input's rank is known.
-    rank = None
-    for value in op.inputs[1:]:
-      if rank is not None:
-        value.get_shape().assert_has_rank(rank)
-      else:
-        rank = value.get_shape().ndims
-    if rank == 0:
-      raise ValueError("Can't concatenate scalars (use tf.pack instead)")
-    return [tensor_shape.unknown_shape(ndims=rank)]
-
-  else:
-    # Merge all the non-concat dims, and sum the concat dim to make an
-    # output shape.
-    concat_dim = int(concat_dim)
-    if concat_dim < 0:
-      raise ValueError("Expected concat_dim >= 0, but got %d" % concat_dim)
-
-    output_shape = op.inputs[1].get_shape()
-    for value in op.inputs[2:]:
-      value_shape = value.get_shape()
-      if value_shape.ndims is not None and concat_dim >= value_shape.ndims:
-        raise ValueError("Expected concat_dim in range [0, %d), but got %d" %
-                         (value_shape.ndims, concat_dim))
-      before = output_shape[:concat_dim].merge_with(value_shape[:concat_dim])
-      at = output_shape[concat_dim] + value_shape[concat_dim]
-      after = output_shape[
-          concat_dim + 1:].merge_with(value_shape[concat_dim + 1:])
-      output_shape = before.concatenate(at).concatenate(after)
-    return [output_shape]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])


 ops.RegisterShape("ConcatOffset")(common_shapes.call_cpp_shape_fn)
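For intuition about what the delegated C++ function still has to compute (this worked example is ours, not part of the diff): concatenating shapes [2, 3] and [4, 3] along concat_dim 0 merges the non-concat dimension (3 with 3) and sums the concat dimension, giving [6, 3].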
@@ -1834,63 +1802,12 @@ ops.RegisterShape("ListDiff")(common_shapes.call_cpp_shape_fn)
 @ops.RegisterShape("Pad")
 @ops.RegisterShape("MirrorPad")
 def _PadShape(op):
-  """Shape function for the Pad op.
-
-  This op has two inputs:
-
-  * input: A rank-N tensor.
-  * paddings: An N-by-2 matrix, in which the i^th row contains the
-    number of padding elements to add before and after `input` in the
-    i^th dimension.
-
-  It has one output, which has the same rank as input, and additional
-  elements according to the values in paddings.
-
-  Args:
-    op: A Pad Operation.
-
-  Returns:
-    A single-element list containing the shape of the output.
-
-  Raises:
-    ValueError: If the input shapes are incompatible.
-  """
-  paddings_shape = op.inputs[1].get_shape().with_rank(2)
-  input_shape = op.inputs[0].get_shape()
-  input_shape = input_shape.with_rank(paddings_shape[0].value)
-  paddings_shape = paddings_shape.merge_with(
-      tensor_shape.matrix(input_shape.ndims, 2))
-  paddings = tensor_util.constant_value(op.inputs[1])
-  if paddings is None:
-    return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
-  else:
-    output_dims = []
-    for i, dim in enumerate(input_shape.dims):
-      if paddings[i, 0] < 0 or paddings[i, 1] < 0:
-        raise ValueError("paddings must be non-negative")
-      output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
-    return [tensor_shape.TensorShape(output_dims)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])


 @ops.RegisterShape("MirrorPadGrad")
 def _MirrorPadGradShape(op):
   """Shape function for the MirrorPadGrad op."""
-  paddings_shape = op.inputs[1].get_shape().with_rank(2)
-  input_shape = op.inputs[0].get_shape().with_rank(paddings_shape[0].value)
-  paddings_shape = paddings_shape.merge_with(tensor_shape.matrix(
-      input_shape.ndims, 2))
-  paddings = tensor_util.constant_value(op.inputs[1])
-  if paddings is None:
-    return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
-
-  output_dims = []
-  for i, dim in enumerate(input_shape.dims):
-    if paddings[i, 0] < 0 or paddings[i, 1] < 0:
-      raise ValueError("Paddings must be non-negative.")
-    if dim < paddings[i, 0] + paddings[i, 1]:
-      raise ValueError("Output dimension is negative.")
-    output_dims.append(dim - paddings[i, 0] - paddings[i, 1])
-  return [tensor_shape.TensorShape(output_dims)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])


 ops.RegisterShape("ReverseSequence")(common_shapes.call_cpp_shape_fn)
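A quick worked example of the removed docstring's rule (ours, for orientation): padding a [2, 3] input with paddings = [[1, 1], [0, 2]] adds 1+1 elements in dimension 0 and 0+2 in dimension 1, so the output shape is [4, 5].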
@@ -1900,58 +1817,12 @@ ops.RegisterShape("ShapeN")(common_shapes.call_cpp_shape_fn)

 @ops.RegisterShape("Transpose")
 def _TransposeShape(op):
-  """Shape function for the Transpose op.
-
-  This op takes two inputs:
-
-  * input: a rank-N tensor of arbitrary shape.
-  * shuffle: a length-N vector.
-
-  Its output is the rank-N tensor computed by permuting the dimensions
-  of input according to shuffle.
-
-  Args:
-    op: A Transpose op.
-
-  Returns:
-    A single-element list containing the shape of the output.
-
-  Raises:
-    ValueError: If the shapes of input and shuffle are incompatible.
-    IndexError: If shuffle contains an index that is >= the rank of input.
-  """
-  input_shape = op.inputs[0].get_shape()
-  transpose_shape = op.inputs[1].get_shape().merge_with(tensor_shape.vector(
-      input_shape.ndims))
-  transpose_vec = tensor_util.constant_value(op.inputs[1])
-  if transpose_vec is None:
-    return [tensor_shape.unknown_shape(ndims=transpose_shape[0].value)]
-  else:
-    return [tensor_shape.TensorShape([input_shape[i]
-                                      for i in transpose_vec.tolist()])]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])


 @ops.RegisterShape("Split")
 def _SplitShape(op):
   """Shape function for the Split op."""
-  split_dim = tensor_util.constant_value(op.inputs[0])
-  num_split = len(op.outputs)
-  input_shape = op.inputs[1].get_shape()
-  if split_dim is None:
-    return [tensor_shape.unknown_shape(ndims=input_shape.ndims)] * num_split
-  else:
-    split_dim = int(split_dim)
-    input_shape = input_shape.with_rank_at_least(split_dim + 1)
-    if not (input_shape[split_dim] % num_split).is_compatible_with(0):
-      raise ValueError(
-          "Number of ways to split should evenly divide the split "
-          "dimension but got split_dim %d (size = %d) and num_split %d" %
-          (split_dim, input_shape[split_dim].value, num_split))
-    prefix = input_shape[:split_dim]
-    size_in_split_dim = input_shape[split_dim] // num_split
-    suffix = input_shape[split_dim + 1:]
-    output_shape = prefix.concatenate(size_in_split_dim).concatenate(suffix)
-    return [output_shape] * num_split
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])


 @ops.RegisterShape("Tile")
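Worked example of the removed _SplitShape logic (ours): splitting a [6, 4] input into num_split = 3 parts along split_dim 0 requires 6 % 3 == 0 and yields three outputs of shape [2, 4].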
@@ -2088,18 +1959,7 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):

 @ops.RegisterShape("EditDistance")
 def _EditDistanceShape(op):
   """Shape function for the EditDistance op."""
-  hypothesis_shape = tensor_util.constant_value(op.inputs[2])
-  truth_shape = tensor_util.constant_value(op.inputs[5])
-  if hypothesis_shape is not None and truth_shape is not None:
-    if len(hypothesis_shape) != len(truth_shape):
-      raise ValueError(
-          "Inconsistent ranks in hypothesis and truth. Saw shapes: %s and %s" %
-          (str(hypothesis_shape), str(truth_shape)))
-    return [tensor_shape.TensorShape(
-        [max(h, t) for h, t in zip(hypothesis_shape[:-1], truth_shape[:-1])])]
-
-  return [tensor_shape.unknown_shape()]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[2, 5])


 # The remaining ops do not change the shape of their inputs.
@@ -2164,80 +2024,7 @@ def _ExtractImagePatchesShape(op):

 @ops.RegisterShape("SpaceToBatch")
 def _SpaceToBatchShape(op):
-  """Shape function for the SpaceToBatch op.
-
-  The output shape is determined by the following inputs/attributes:
-
-  * input: A rank-4 tensor with shape [B, H, W, D]
-  * paddings: A 2-by-2 matrix, specified as follows:
-
-        paddings = [[pad_top, pad_bottom], [pad_left, pad_right]],
-
-    implying effective padded spatial dimensions:
-
-        Hp = pad_top + H + pad_bottom
-        Wp = pad_left + W + pad_right
-
-    Both Hp and Wp must be multiples of block_size.
-  * block_size: an int.
-
-  Its output is also a rank-4 tensor with shape:
-
-      [B*block_size*block_size, Hp/block_size, Wp/block_size, D]
-
-  Args:
-    op: A SpaceToBatch op.
-
-  Returns:
-    A single-element list containing the shape of the output.
-
-  Raises:
-    ValueError: If the shapes of inputs are not as expected.
-    IndexError: If block_size does not divide Wp or Hp.
-  """
-  # Check that the input tensor is 4-D.
-  try:
-    input_shape = op.inputs[0].get_shape().with_rank(4)
-  except ValueError:
-    raise ValueError(
-        "tf.space_to_batch() requires 4-D input tensor.")
-
-  # Check that the paddings tensor is a matrix with shape [2, 2].
-  try:
-    paddings_shape = op.inputs[1].get_shape().with_rank(2)
-  except ValueError:
-    raise ValueError(
-        "tf.space_to_batch() requires 2-D paddings tensor.")
-
-  if paddings_shape[0] != 2 or paddings_shape[1] != 2:
-    raise ValueError(
-        "tf.space_to_batch() requires input paddings with shape [2, 2].")
-
-  block_size = op.get_attr("block_size")
-  if block_size <= 1:
-    raise ValueError("Attribute block_size has to be > 1.")
-
-  paddings = tensor_util.constant_value(op.inputs[1])
-  if paddings is not None:
-    if (paddings[0, 0] < 0 or paddings[0, 1] < 0 or
-        paddings[1, 0] < 0 or paddings[1, 1] < 0):
-      raise ValueError("paddings cannot be negative.")
-
-    input_height = input_shape[1] + paddings[0, 0] + paddings[0, 1]
-    input_width = input_shape[2] + paddings[1, 0] + paddings[1, 1]
-
-    if input_height % block_size > 0 or input_width % block_size > 0:
-      raise IndexError("block_size needs to divide both width and height.")
-  else:
-    input_height = tensor_shape.Dimension(None)
-    input_width = tensor_shape.Dimension(None)
-
-  batch = input_shape[0] * block_size * block_size
-  height = input_height // block_size
-  width = input_width // block_size
-  depth = input_shape[3]
-
-  return [tensor_shape.TensorShape([batch, height, width, depth])]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])


 @ops.RegisterShape("BatchToSpace")
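Plugging numbers into the removed docstring's formula (our example): input [B, H, W, D] = [1, 10, 10, 3], zero paddings, and block_size = 2 give Hp = Wp = 10, so the output is [1*2*2, 10/2, 10/2, 3] = [4, 5, 5, 3]; with padding values unknown, only the batch dimension 4 survives, matching the "[4,?,?,d0_3]" assertion in the C++ test above.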
@@ -2584,33 +2371,7 @@ def one_hot(indices, depth, on_value=None, off_value=None,

 @ops.RegisterShape("OneHot")
 def _OneHotShape(op):
-  """Shape function for the OneHot op.
-
-  It closely follows the code in the .cc implementation.
-
-  Args:
-    op: A OneHot Operation.
-
-  Returns:
-    A single-element list containing the shape of the output.
-
-  Raises:
-    ValueError: if axis < -1.
-  """
-  indices_shape = op.inputs[0].get_shape()
-  indices_dims = indices_shape.ndims
-  depth = tensor_util.constant_value(op.inputs[1])
-  axis = op.get_attr("axis")
-
-  if axis < -1:
-    raise ValueError("axis must be >= -1")
-
-  new_shape = None
-  if indices_dims is not None:
-    new_shape = indices_shape.as_list()
-    new_shape.insert(axis % (indices_dims + 1), depth)
-
-  return [tensor_shape.TensorShape(new_shape)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])


 @ops.RegisterShape("PlaceholderWithDefault")