Infer unknown shapes for functions in C++

As we are implementing function support through C API, the new code path
runs shape inference of Operations representing functions, but we don't
yet support shape inference for functions.

Before this change, adding a function NodeDef would result in error.
This change pairs all functions with a shape inference function that
sets all output shapes to unknown.

PiperOrigin-RevId: 163830793
This commit is contained in:
A. Unique TensorFlower 2017-08-01 08:13:52 -07:00 committed by TensorFlower Gardener
parent 3cc5fc0886
commit 203c3f5fd4
4 changed files with 15 additions and 2 deletions

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb_text.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -867,6 +868,11 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
return Status::OK();
}
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
: fdef(fdef_in),
op_registration_data(fdef.signature(), shape_inference::UnknownShape) {}
FunctionLibraryDefinition::FunctionLibraryDefinition(
const FunctionLibraryDefinition& other)
: default_registry_(other.default_registry_), func_grad_(other.func_grad_) {

View File

@ -348,8 +348,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
private:
// TODO(cwhipkey): support shape functions in FunctionDefLibrary.
struct FunctionDefAndOpRegistration {
FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
: fdef(fdef_in), op_registration_data(fdef.signature()) {}
FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
FunctionDef fdef;
OpRegistrationData op_registration_data;

View File

@ -938,6 +938,12 @@ TEST(FunctionLibraryDefinitionTest, LookUp) {
ASSERT_NE(op_def, nullptr);
EXPECT_EQ(op_def->DebugString(),
test::function::XTimesTwo().signature().DebugString());
const OpRegistrationData* op_reg_data;
TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data));
ASSERT_NE(op_reg_data, nullptr);
// Shape inference function is initialized to UnknownShape.
ASSERT_NE(op_reg_data->shape_inference_fn, nullptr);
}
TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {

View File

@ -38,6 +38,8 @@ struct OpRegistrationData {
public:
OpRegistrationData() {}
OpRegistrationData(const OpDef& def) : op_def(def) {}
OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn)
: op_def(def), shape_inference_fn(fn) {}
OpDef op_def;
OpShapeInferenceFn shape_inference_fn;