Merge pull request from dnguyen28061:scalar_shape

PiperOrigin-RevId: 322706987
Change-Id: I9eb4ecd72ea393501a7eb9f1a5e51383d50780fd
This commit is contained in:
TensorFlower Gardener 2020-07-22 19:48:11 -07:00
commit 0b2019bdec
3 changed files with 22 additions and 0 deletions

View File

@ -104,6 +104,12 @@ TF_ShapeHandle* TF_NewShapeHandle() {
return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
}
TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) {
auto* handle = new ShapeHandle;
*handle = reinterpret_cast<InferenceContext*>(ctx)->Scalar();
return reinterpret_cast<TF_ShapeHandle*>(handle);
}
TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
TF_ShapeInferenceContext* ctx, size_t size) {
auto* handle = new ShapeHandle;

View File

@ -280,6 +280,11 @@ extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx,
int i, TF_ShapeHandle* handle,
TF_Status* status);
// Returns a newly-allocated scalar shape handle. The returned handle should
// be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextScalar(
TF_ShapeInferenceContext* ctx);
// Returns a newly-allocate shape handle representing a vector of the given
// size. The returned handle should be freed with TF_DeleteShapeHandle.
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(

View File

@ -316,5 +316,16 @@ TEST(OpsTest, ShapeInferenceSubshape) {
TF_DeleteShapeHandle(handle);
}
TEST(OpsTest, ShapeInferenceScalarShape) {
NodeDef def;
shape_inference::InferenceContext c(0, def, MakeOpDef(0, 0), {S({})}, {}, {},
{});
TF_ShapeHandle* TF_scalar_shape = TF_ShapeInferenceContextScalar(C_CTX(&c));
shape_inference::ShapeHandle* scalar_shape =
reinterpret_cast<shape_inference::ShapeHandle*>(TF_scalar_shape);
ASSERT_EQ("[]", c.DebugString(*scalar_shape));
TF_DeleteShapeHandle(TF_scalar_shape);
}
} // namespace
} // namespace tensorflow