C++ shape inference for OneHot op.
Change: 129145757
This commit is contained in:
parent
9b2c80c435
commit
5a828e3a9c
@ -2877,6 +2877,32 @@ REGISTER_OP("OneHot")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("TI: {uint8, int32, int64} = DT_INT64")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
int32 axis;
|
||||
TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
|
||||
if (axis < -1) return errors::InvalidArgument("axis must be >= -1");
|
||||
|
||||
const Dimension* depth;
|
||||
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));
|
||||
|
||||
const Shape* indices = c->input(0);
|
||||
if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);
|
||||
|
||||
int32 new_rank = c->Rank(indices) + 1;
|
||||
// We need to add new_rank to axis in the case the axis is -1 because
|
||||
// C++ returns negative values from % if the dividend is negative.
|
||||
int32 depth_index = (axis + new_rank) % new_rank;
|
||||
// Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
|
||||
const Shape* front;
|
||||
const Shape* back;
|
||||
const Shape* out;
|
||||
TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
|
||||
TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
|
||||
TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
|
||||
TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out));
|
||||
c->set_output(0, out);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
Returns a one-hot tensor.
|
||||
|
||||
|
@ -844,4 +844,37 @@ TEST(ArrayOpsTest, EditDistance_ShapeFn) {
|
||||
"[?];[?];[1];[?];[?];[4]");
|
||||
}
|
||||
|
||||
TEST(ArrayOpsTest, OneHot_ShapeFn) {
|
||||
ShapeInferenceTestOp op("OneHot");
|
||||
op.input_tensors.resize(4);
|
||||
auto set_axis = [&op](int axis) {
|
||||
TF_CHECK_OK(NodeDefBuilder("test", "OneHot")
|
||||
.Input("indices", 0, DT_FLOAT)
|
||||
.Input("depth", 1, DT_INT32)
|
||||
.Input("on_value", 2, DT_FLOAT)
|
||||
.Input("off_value", 3, DT_FLOAT)
|
||||
.Attr("axis", axis)
|
||||
.Finalize(&op.node_def));
|
||||
};
|
||||
|
||||
// Invalid axis value.
|
||||
set_axis(-2);
|
||||
INFER_ERROR("axis must be >= -1", op, "?;?;?;?");
|
||||
set_axis(1);
|
||||
|
||||
// If indices shape is unknown, we return an unknown shape.
|
||||
INFER_OK(op, "?;[];?;?", "?");
|
||||
|
||||
// Depth must be scalar.
|
||||
Tensor depth = test::AsTensor<int32>({1, 2});
|
||||
op.input_tensors[1] = &depth;
|
||||
INFER_ERROR("Input must be scalar but has rank 1", op, "?;[2];?;?");
|
||||
|
||||
// Full information is available.
|
||||
depth = test::AsScalar<int32>(2);
|
||||
INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,2,d0_1,d0_2]");
|
||||
set_axis(-1);
|
||||
INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,d0_1,d0_2,2]");
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user