Merge pull request #41429 from dnguyen28061:scalar_shape
PiperOrigin-RevId: 322706987 Change-Id: I9eb4ecd72ea393501a7eb9f1a5e51383d50780fd
This commit is contained in:
commit
0b2019bdec
tensorflow/c
@ -104,6 +104,12 @@ TF_ShapeHandle* TF_NewShapeHandle() {
|
||||
return reinterpret_cast<TF_ShapeHandle*>(new ShapeHandle);
|
||||
}
|
||||
|
||||
TF_ShapeHandle* TF_ShapeInferenceContextScalar(TF_ShapeInferenceContext* ctx) {
|
||||
auto* handle = new ShapeHandle;
|
||||
*handle = reinterpret_cast<InferenceContext*>(ctx)->Scalar();
|
||||
return reinterpret_cast<TF_ShapeHandle*>(handle);
|
||||
}
|
||||
|
||||
TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
|
||||
TF_ShapeInferenceContext* ctx, size_t size) {
|
||||
auto* handle = new ShapeHandle;
|
||||
|
@ -280,6 +280,11 @@ extern void TF_ShapeInferenceContextSetOutput(TF_ShapeInferenceContext* ctx,
|
||||
int i, TF_ShapeHandle* handle,
|
||||
TF_Status* status);
|
||||
|
||||
// Returns a newly-allocated scalar shape handle. The returned handle should
|
||||
// be freed with TF_DeleteShapeHandle.
|
||||
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextScalar(
|
||||
TF_ShapeInferenceContext* ctx);
|
||||
|
||||
// Returns a newly-allocate shape handle representing a vector of the given
|
||||
// size. The returned handle should be freed with TF_DeleteShapeHandle.
|
||||
TF_CAPI_EXPORT extern TF_ShapeHandle* TF_ShapeInferenceContextVectorFromSize(
|
||||
|
@ -316,5 +316,16 @@ TEST(OpsTest, ShapeInferenceSubshape) {
|
||||
TF_DeleteShapeHandle(handle);
|
||||
}
|
||||
|
||||
TEST(OpsTest, ShapeInferenceScalarShape) {
|
||||
NodeDef def;
|
||||
shape_inference::InferenceContext c(0, def, MakeOpDef(0, 0), {S({})}, {}, {},
|
||||
{});
|
||||
TF_ShapeHandle* TF_scalar_shape = TF_ShapeInferenceContextScalar(C_CTX(&c));
|
||||
shape_inference::ShapeHandle* scalar_shape =
|
||||
reinterpret_cast<shape_inference::ShapeHandle*>(TF_scalar_shape);
|
||||
ASSERT_EQ("[]", c.DebugString(*scalar_shape));
|
||||
TF_DeleteShapeHandle(TF_scalar_shape);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user