Infer unknown shapes for functions in C++
As we are implementing function support through C API, the new code path runs shape inference of Operations representing functions, but we don't yet support shape inference for functions. Before this change, adding a function NodeDef would result in error. This change pairs all functions with a shape inference function that sets all output shapes to unknown. PiperOrigin-RevId: 163830793
This commit is contained in:
parent
3cc5fc0886
commit
203c3f5fd4
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/function.pb_text.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
@ -867,6 +868,11 @@ Status FunctionCallFrame::SetRetval(int index, const Tensor& val) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
|
||||
FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
|
||||
: fdef(fdef_in),
|
||||
op_registration_data(fdef.signature(), shape_inference::UnknownShape) {}
|
||||
|
||||
FunctionLibraryDefinition::FunctionLibraryDefinition(
|
||||
const FunctionLibraryDefinition& other)
|
||||
: default_registry_(other.default_registry_), func_grad_(other.func_grad_) {
|
||||
|
@ -348,8 +348,7 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
|
||||
private:
|
||||
// TODO(cwhipkey): support shape functions in FunctionDefLibrary.
|
||||
struct FunctionDefAndOpRegistration {
|
||||
FunctionDefAndOpRegistration(const FunctionDef& fdef_in)
|
||||
: fdef(fdef_in), op_registration_data(fdef.signature()) {}
|
||||
FunctionDefAndOpRegistration(const FunctionDef& fdef_in);
|
||||
|
||||
FunctionDef fdef;
|
||||
OpRegistrationData op_registration_data;
|
||||
|
@ -938,6 +938,12 @@ TEST(FunctionLibraryDefinitionTest, LookUp) {
|
||||
ASSERT_NE(op_def, nullptr);
|
||||
EXPECT_EQ(op_def->DebugString(),
|
||||
test::function::XTimesTwo().signature().DebugString());
|
||||
|
||||
const OpRegistrationData* op_reg_data;
|
||||
TF_EXPECT_OK(lib_def.LookUp("XTimesTwo", &op_reg_data));
|
||||
ASSERT_NE(op_reg_data, nullptr);
|
||||
// Shape inference function is initialized to UnknownShape.
|
||||
ASSERT_NE(op_reg_data->shape_inference_fn, nullptr);
|
||||
}
|
||||
|
||||
TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
|
||||
|
@ -38,6 +38,8 @@ struct OpRegistrationData {
|
||||
public:
|
||||
OpRegistrationData() {}
|
||||
OpRegistrationData(const OpDef& def) : op_def(def) {}
|
||||
OpRegistrationData(const OpDef& def, const OpShapeInferenceFn& fn)
|
||||
: op_def(def), shape_inference_fn(fn) {}
|
||||
|
||||
OpDef op_def;
|
||||
OpShapeInferenceFn shape_inference_fn;
|
||||
|
Loading…
x
Reference in New Issue
Block a user