Add variable operators to Flex

This is required for enabling flex fallback for complicated variable operators
for stateful random operators and so on.

PiperOrigin-RevId: 353578616
Change-Id: Ic71f36199ff437c2efa6f259ed2fd474694617d9
This commit is contained in:
Jaesung Chung 2021-01-24 22:06:24 -08:00 committed by TensorFlower Gardener
parent 167d0c0532
commit 77a9b64c88
2 changed files with 23 additions and 0 deletions

View File

@ -5929,10 +5929,12 @@ filegroup(
"relu_op.h", "relu_op.h",
"relu_op_functor.h", "relu_op_functor.h",
"reshape_util.h", "reshape_util.h",
"resource_variable_ops.h",
"reverse_op.h", "reverse_op.h",
"roll_op.h", "roll_op.h",
"save_restore_tensor.h", "save_restore_tensor.h",
"scan_ops.h", "scan_ops.h",
"scatter_functor.h",
"scatter_nd_op.h", "scatter_nd_op.h",
"segment_reduction_ops.h", "segment_reduction_ops.h",
"segment_reduction_ops_impl.h", "segment_reduction_ops_impl.h",
@ -6177,6 +6179,7 @@ filegroup(
"regex_full_match_op.cc", "regex_full_match_op.cc",
"relu_op.cc", "relu_op.cc",
"reshape_util.cc", "reshape_util.cc",
"resource_variable_ops.cc",
"restore_op.cc", "restore_op.cc",
"reverse_op.cc", "reverse_op.cc",
"roll_op.cc", "roll_op.cc",
@ -6184,6 +6187,7 @@ filegroup(
"save_restore_tensor.cc", "save_restore_tensor.cc",
"save_restore_v2_ops.cc", "save_restore_v2_ops.cc",
"scan_ops.cc", "scan_ops.cc",
"scatter_functor.cc",
"scatter_nd_op.cc", "scatter_nd_op.cc",
"scatter_nd_op_cpu_impl_0.cc", "scatter_nd_op_cpu_impl_0.cc",
"scatter_nd_op_cpu_impl_1.cc", "scatter_nd_op_cpu_impl_1.cc",

View File

@ -59,7 +59,10 @@ const std::set<std::string>& GetFlexAllowlist() {
"Assert", "Assert",
"Assign", "Assign",
"AssignAdd", "AssignAdd",
"AssignAddVariableOp",
"AssignSub", "AssignSub",
"AssignSubVariableOp",
"AssignVariableOp",
"Atan", "Atan",
"Atan2", "Atan2",
"AudioSpectrogram", "AudioSpectrogram",
@ -154,6 +157,7 @@ const std::set<std::string>& GetFlexAllowlist() {
"DepthToSpace", "DepthToSpace",
"DepthwiseConv2dNative", "DepthwiseConv2dNative",
"Dequantize", "Dequantize",
"DestroyResourceOp",
"DestroyTemporaryVariable", "DestroyTemporaryVariable",
"Diag", "Diag",
"DiagPart", "DiagPart",
@ -375,6 +379,7 @@ const std::set<std::string>& GetFlexAllowlist() {
"RandomUniformInt", "RandomUniformInt",
"Range", "Range",
"Rank", "Rank",
"ReadVariableOp",
"Real", "Real",
"RealDiv", "RealDiv",
"Reciprocal", "Reciprocal",
@ -420,11 +425,20 @@ const std::set<std::string>& GetFlexAllowlist() {
"ResourceApplyProximalAdagrad", "ResourceApplyProximalAdagrad",
"ResourceApplyProximalGradientDescent", "ResourceApplyProximalGradientDescent",
"ResourceApplyRMSProp", "ResourceApplyRMSProp",
"ResourceGather",
"ResourceGatherNd",
"ResourceScatterAdd",
"ResourceScatterDiv",
"ResourceScatterMax",
"ResourceScatterMin",
"ResourceScatterMul",
"ResourceScatterNdAdd", "ResourceScatterNdAdd",
"ResourceScatterNdMax", "ResourceScatterNdMax",
"ResourceScatterNdMin", "ResourceScatterNdMin",
"ResourceScatterNdSub", "ResourceScatterNdSub",
"ResourceScatterNdUpdate", "ResourceScatterNdUpdate",
"ResourceScatterSub",
"ResourceScatterUpdate",
"ResourceSparseApplyAdadelta", "ResourceSparseApplyAdadelta",
"ResourceSparseApplyAdagrad", "ResourceSparseApplyAdagrad",
"ResourceSparseApplyAdagradDA", "ResourceSparseApplyAdagradDA",
@ -660,7 +674,10 @@ const std::set<std::string>& GetFlexAllowlist() {
"UnsortedSegmentProd", "UnsortedSegmentProd",
"UnsortedSegmentSum", "UnsortedSegmentSum",
"UnwrapDatasetVariant", "UnwrapDatasetVariant",
"VarHandleOp",
"VarIsInitializedOp",
"Variable", "Variable",
"VariableShape",
"VariableV2", "VariableV2",
"Where", "Where",
"WrapDatasetVariant", "WrapDatasetVariant",
@ -679,10 +696,12 @@ const std::set<std::string>& GetFlexAllowlist() {
"_ListToArray", "_ListToArray",
"_ParallelConcatStart", "_ParallelConcatStart",
"_ParallelConcatUpdate", "_ParallelConcatUpdate",
"_ReadVariablesOp",
"_Recv", "_Recv",
"_Retval", "_Retval",
"_Send", "_Send",
"_SwitchN", "_SwitchN",
"_VarHandlesOp",
// go/keep-sorted end // go/keep-sorted end
}); });
return *allowlisted_flex_ops; return *allowlisted_flex_ops;