C++ shape inference for OneHot op.

Change: 129145757
This commit is contained in:
Suharsh Sivakumar 2016-08-02 14:01:04 -08:00 committed by TensorFlower Gardener
parent 9b2c80c435
commit 5a828e3a9c
2 changed files with 59 additions and 0 deletions

View File

@ -2877,6 +2877,32 @@ REGISTER_OP("OneHot")
.Output("output: T")
.Attr("T: type")
.Attr("TI: {uint8, int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
int32 axis;
TF_RETURN_IF_ERROR(c->GetAttr("axis", &axis));
if (axis < -1) return errors::InvalidArgument("axis must be >= -1");
const Dimension* depth;
TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &depth));
const Shape* indices = c->input(0);
if (!c->RankKnown(indices)) return shape_inference::UnknownShape(c);
int32 new_rank = c->Rank(indices) + 1;
// We need to add new_rank to axis in the case the axis is -1 because
// C++ returns negative values from % if the dividend is negative.
int32 depth_index = (axis + new_rank) % new_rank;
// Out shape is indices[0:depth_index] + [depth] + indices[depth_index:].
const Shape* front;
const Shape* back;
const Shape* out;
TF_RETURN_IF_ERROR(c->Subshape(indices, 0, depth_index, &front));
TF_RETURN_IF_ERROR(c->Subshape(indices, depth_index, &back));
TF_RETURN_IF_ERROR(c->Concatenate(front, c->Vector(depth), &front));
TF_RETURN_IF_ERROR(c->Concatenate(front, back, &out));
c->set_output(0, out);
return Status::OK();
})
.Doc(R"doc(
Returns a one-hot tensor.

View File

@ -844,4 +844,37 @@ TEST(ArrayOpsTest, EditDistance_ShapeFn) {
"[?];[?];[1];[?];[?];[4]");
}
TEST(ArrayOpsTest, OneHot_ShapeFn) {
ShapeInferenceTestOp op("OneHot");
op.input_tensors.resize(4);
auto set_axis = [&op](int axis) {
TF_CHECK_OK(NodeDefBuilder("test", "OneHot")
.Input("indices", 0, DT_FLOAT)
.Input("depth", 1, DT_INT32)
.Input("on_value", 2, DT_FLOAT)
.Input("off_value", 3, DT_FLOAT)
.Attr("axis", axis)
.Finalize(&op.node_def));
};
// Invalid axis value.
set_axis(-2);
INFER_ERROR("axis must be >= -1", op, "?;?;?;?");
set_axis(1);
// If indices shape is unknown, we return an unknown shape.
INFER_OK(op, "?;[];?;?", "?");
// Depth must be scalar.
Tensor depth = test::AsTensor<int32>({1, 2});
op.input_tensors[1] = &depth;
INFER_ERROR("Input must be scalar but has rank 1", op, "?;[2];?;?");
// Full information is available.
depth = test::AsScalar<int32>(2);
INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,2,d0_1,d0_2]");
set_axis(-1);
INFER_OK(op, "[1,3,4];[];?;?", "[d0_0,d0_1,d0_2,2]");
}
} // end namespace tensorflow