From 22f801369fd989332b1ada3091b205295dba8606 Mon Sep 17 00:00:00 2001 From: Artem Belevich <tra@google.com> Date: Thu, 13 Feb 2020 14:15:35 -0800 Subject: [PATCH] Simplify gradient exclusions data to speed up compilation w/ clang on windows. PiperOrigin-RevId: 294998085 Change-Id: Ie56b8f2cf4ed1e5fd8e2a641947b2d69f316e86a --- tensorflow/python/eager/BUILD | 1 + .../eager/gradient_input_output_exclusions.py | 109 +- .../eager/pywrap_gradient_exclusions.cc | 1699 +++++++++-------- .../python/eager/pywrap_gradient_exclusions.h | 21 +- tensorflow/python/eager/pywrap_tfe_src.cc | 18 +- 5 files changed, 940 insertions(+), 908 deletions(-) diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 2cede68fe1d..b0fe6398f01 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -54,6 +54,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/hash", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:variant", ], ) diff --git a/tensorflow/python/eager/gradient_input_output_exclusions.py b/tensorflow/python/eager/gradient_input_output_exclusions.py index dfd79b8204a..2340ad41715 100644 --- a/tensorflow/python/eager/gradient_input_output_exclusions.py +++ b/tensorflow/python/eager/gradient_input_output_exclusions.py @@ -63,10 +63,33 @@ limitations under the License. _INCLUDES = """ #include "tensorflow/python/eager/pywrap_gradient_exclusions.h" +#include "absl/types/optional.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" using tensorflow::string; + +namespace { +// Keep static data in a format that's easy to init statically. +struct OpIndexInfo { + const char *op_name; + int num_indices; + std::array<int, 4> unused_indices; +}; + +// Helper function to initialize FlatMap<string,FlatSet> from OpIndexInfo. +template <typename T> +auto OpGradientInfoInit(const T &a) { + auto *m = new tensorflow::gtl::FlatMap<string, tensorflow::gtl::FlatSet<int>>; + for (const auto &item : a) { + m->emplace(string(item.op_name), + tensorflow::gtl::FlatSet<int>( + item.unused_indices.begin(), + item.unused_indices.begin() + item.num_indices)); + } + return m; +} +} // namespace """ _EXCLUDED_OPS = [ @@ -281,7 +304,6 @@ def get_entries(attr_name): """ assert attr_name in ["inputs", "outputs"] entries = {} - spaces = " " for op_type in ops._gradient_registry.list(): # pylint: disable=protected-access if op_type in _EXCLUDED_OPS: continue @@ -291,72 +313,57 @@ def get_entries(attr_name): if gradient_fn is None: # NotDifferentiable if num_values != -1: - entries[op_type] = spaces + "{\"%s\", {true, {}}}," % op_type + entries[op_type] = "{\"%s\"}," % op_type continue used_tensors = _live_tensors(gradient_fn, attr_name=attr_name) if used_tensors is _ALL: continue elif not used_tensors: - entries[op_type] = spaces + "{\"%s\", {true, {}}}," % op_type + entries[op_type] = "{\"%s\"}," % op_type else: all_tensors = set(range(num_values)) unused_tensors = all_tensors - used_tensors if unused_tensors: - entries[op_type] = spaces + "{\"%s\", {false, {%s}}}," % ( - op_type, ", ".join(str(i) for i in sorted(list(unused_tensors)))) + unused_tensor_list = sorted(list(unused_tensors)) + entries[op_type] = "{\"%s\", %d, {%s}}," % ( + op_type, len(unused_tensor_list), ", ".join( + str(i) for i in unused_tensor_list)) return entries +def get_function(name, entries): + """Generates lookup function with given name and lookup table entries.""" + contents = """ +absl::optional<tensorflow::gtl::FlatSet<int>> {name}( + const tensorflow::string &op_name) {{ + static std::array<OpIndexInfo, {count}> a = {{{{ +""".format( + name=name, count=len(entries) + 1) + contents += " " + contents += "\n ".join(entries[op_type] for op_type in sorted(entries)) + contents += "\n {\"VarHandleOp\"}," + contents += """ + }}; + static const auto &m = *OpGradientInfoInit(a); + + auto it = m.find(op_name); + if (it != m.end()) { + return it->second; + } + return absl::nullopt; +} +""" + return contents + + def get_contents(): """Returns contents for the generated file.""" contents = "" contents += _GENERATED_FILE_HEADER + _INCLUDES - contents += """ -bool OpGradientDoesntRequireInputIndices( - const string& op_name, - std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) { - static tensorflow::gtl::FlatMap< - string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m = - new tensorflow::gtl::FlatMap< - string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({ -""" - - entries = get_entries("inputs") - contents += "\n".join(entries[op_type] for op_type in sorted(entries)) - contents += "\n {\"VarHandleOp\", {true, {}}},\n" - contents += """ }); - - auto it = m->find(op_name); - - if (it == m->end()) return false; - - *output = &it->second; - return true; -} -""" - contents += """ -bool OpGradientDoesntRequireOutputIndices( - const string& op_name, - std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) { - static tensorflow::gtl::FlatMap< - string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m = - new tensorflow::gtl::FlatMap< - string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({ -""" - - entries = get_entries("outputs") - contents += "\n".join(entries[op_type] for op_type in sorted(entries)) - contents += "\n {\"VarHandleOp\", {true, {}}},\n" - contents += """ }); - - auto it = m->find(op_name); - - if (it == m->end()) return false; - - *output = &it->second; - return true; -} -""" + contents += get_function("OpGradientUnusedInputIndices", + get_entries("inputs")) + contents += get_function("OpGradientUnusedOutputIndices", + get_entries("outputs")) return contents diff --git a/tensorflow/python/eager/pywrap_gradient_exclusions.cc b/tensorflow/python/eager/pywrap_gradient_exclusions.cc index 6647728828c..afae0b57ee7 100644 --- a/tensorflow/python/eager/pywrap_gradient_exclusions.cc +++ b/tensorflow/python/eager/pywrap_gradient_exclusions.cc @@ -20,857 +20,872 @@ limitations under the License. #include "tensorflow/python/eager/pywrap_gradient_exclusions.h" +#include "absl/types/optional.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" using tensorflow::string; -bool OpGradientDoesntRequireInputIndices( - const string& op_name, - std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) { - static tensorflow::gtl::FlatMap< - string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m = - new tensorflow::gtl::FlatMap< - string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({ - {"Acosh", {true, {}}}, - {"AllToAll", {false, {0}}}, - {"ApproximateEqual", {true, {}}}, - {"ArgMax", {true, {}}}, - {"ArgMin", {true, {}}}, - {"AsString", {true, {}}}, - {"Asinh", {true, {}}}, - {"Assign", {true, {}}}, - {"AssignAdd", {true, {}}}, - {"AssignSub", {true, {}}}, - {"AudioSummary", {true, {}}}, - {"AudioSummaryV2", {true, {}}}, - {"AvgPool3DGrad", {false, {1}}}, - {"AvgPoolGrad", {false, {1}}}, - {"BatchNormWithGlobalNormalization", {false, {3}}}, - {"BatchToSpace", {false, {0}}}, - {"BatchToSpaceND", {false, {0}}}, - {"BiasAdd", {true, {}}}, - {"BiasAddV1", {true, {}}}, - {"BitwiseAnd", {true, {}}}, - {"BitwiseOr", {true, {}}}, - {"BitwiseXor", {true, {}}}, - {"BroadcastGradientArgs", {true, {}}}, - {"CSRSparseMatrixToDense", {true, {}}}, - {"CTCBeamSearchDecoder", {true, {}}}, - {"CTCGreedyDecoder", {true, {}}}, - {"CTCLoss", {true, {}}}, - {"CTCLossV2", {true, {}}}, - {"Ceil", {true, {}}}, - {"CheckNumerics", {true, {}}}, - {"CheckNumericsV2", {true, {}}}, - {"Cholesky", {true, {}}}, - {"CollectivePermute", {false, {0}}}, - {"Conj", {true, {}}}, - {"ConjugateTranspose", {false, {0}}}, - {"Const", {true, {}}}, - {"Conv2DBackpropFilter", {false, {1}}}, - {"Conv2DBackpropInput", {false, {0}}}, - {"Conv3DBackpropFilterV2", {false, {1}}}, - {"Conv3DBackpropInputV2", {false, {0}}}, - {"CropAndResize", {false, {3}}}, - {"CrossReplicaSum", {false, {0}}}, - {"Cumsum", {false, {0}}}, - {"DebugGradientIdentity", {true, {}}}, - {"DebugGradientRefIdentity", {true, {}}}, - {"DebugIdentityV2", {true, {}}}, - {"DecodeBase64", {true, {}}}, - {"DecodePaddedRaw", {true, {}}}, - {"DecodeProtoV2", {true, {}}}, - {"DecodeRaw", {true, {}}}, - {"DeleteSessionTensor", {true, {}}}, - {"DenseToCSRSparseMatrix", {true, {}}}, - {"DenseToDenseSetOperation", {true, {}}}, - {"DenseToSparseSetOperation", {true, {}}}, - {"DepthToSpace", {true, {}}}, - {"DepthwiseConv2dNativeBackpropFilter", {false, {1}}}, - {"DepthwiseConv2dNativeBackpropInput", {false, {0}}}, - {"Diag", {true, {}}}, - {"DiagPart", {true, {}}}, - {"DrawBoundingBoxes", {true, {}}}, - {"EditDistance", {true, {}}}, - {"Elu", {true, {}}}, - {"EncodeBase64", {true, {}}}, - {"EnsureShape", {true, {}}}, - {"Enter", {true, {}}}, - {"Equal", {true, {}}}, - {"Erfinv", {true, {}}}, - {"Exit", {true, {}}}, - {"Exp", {true, {}}}, - {"ExpandDims", {false, {1}}}, - {"ExtractGlimpse", {true, {}}}, - {"FFT", {true, {}}}, - {"FFT2D", {true, {}}}, - {"FFT3D", {true, {}}}, - {"Fill", {true, {}}}, - {"FixedLengthRecordReader", {true, {}}}, - {"Floor", {true, {}}}, - {"FloorDiv", {true, {}}}, - {"FusedBatchNorm", {false, {2}}}, - {"FusedBatchNormGradV3", {false, {5}}}, - {"FusedBatchNormV2", {false, {2}}}, - {"FusedBatchNormV3", {false, {2}}}, - {"GenerateBoundingBoxProposals", {true, {}}}, - {"GenerateVocabRemapping", {true, {}}}, - {"GetSessionHandle", {true, {}}}, - {"GetSessionHandleV2", {true, {}}}, - {"GetSessionTensor", {true, {}}}, - {"Greater", {true, {}}}, - {"GreaterEqual", {true, {}}}, - {"HSVToRGB", {true, {}}}, - {"HashTable", {true, {}}}, - {"HashTableV2", {true, {}}}, - {"HistogramSummary", {true, {}}}, - {"IFFT", {true, {}}}, - {"IFFT2D", {true, {}}}, - {"IFFT3D", {true, {}}}, - {"Identity", {true, {}}}, - {"IdentityN", {true, {}}}, - {"IdentityReader", {true, {}}}, - {"Imag", {true, {}}}, - {"ImageProjectiveTransformV2", {false, {2}}}, - {"ImageSummary", {true, {}}}, - {"InitializeTable", {true, {}}}, - {"InitializeTableFromTextFile", {true, {}}}, - {"InitializeTableFromTextFileV2", {true, {}}}, - {"InitializeTableV2", {true, {}}}, - {"Inv", {true, {}}}, - {"Invert", {true, {}}}, - {"InvertPermutation", {true, {}}}, - {"LMDBReader", {true, {}}}, - {"LeakyReluGrad", {false, {0}}}, - {"LeftShift", {true, {}}}, - {"Less", {true, {}}}, - {"LessEqual", {true, {}}}, - {"LinSpace", {true, {}}}, - {"LoadAndRemapMatrix", {true, {}}}, - {"LogSoftmax", {true, {}}}, - {"LogicalAnd", {true, {}}}, - {"LogicalNot", {true, {}}}, - {"LogicalOr", {true, {}}}, - {"LookupTableFind", {true, {}}}, - {"LookupTableFindV2", {true, {}}}, - {"LookupTableInsert", {true, {}}}, - {"LookupTableInsertV2", {true, {}}}, - {"LookupTableSize", {true, {}}}, - {"LookupTableSizeV2", {true, {}}}, - {"LoopCond", {true, {}}}, - {"MatrixBandPart", {false, {0}}}, - {"MatrixDiag", {true, {}}}, - {"MatrixDiagPartV2", {false, {2}}}, - {"MatrixDiagPartV3", {false, {2}}}, - {"MatrixDiagV2", {false, {0, 2, 3, 4}}}, - {"MatrixDiagV3", {false, {0, 2, 3, 4}}}, - {"MatrixInverse", {true, {}}}, - {"MatrixSetDiagV2", {false, {0}}}, - {"MatrixSetDiagV3", {false, {0}}}, - {"MatrixSolve", {false, {1}}}, - {"MatrixSquareRoot", {true, {}}}, - {"MaxPool3DGrad", {false, {2}}}, - {"MaxPool3DGradGrad", {false, {2}}}, - {"MaxPoolGrad", {false, {2}}}, - {"MaxPoolGradGrad", {false, {2}}}, - {"MaxPoolGradV2", {false, {2}}}, - {"MirrorPad", {false, {0}}}, - {"MirrorPadGrad", {false, {0}}}, - {"Multinomial", {true, {}}}, - {"MutableDenseHashTable", {true, {}}}, - {"MutableDenseHashTableV2", {true, {}}}, - {"MutableHashTable", {true, {}}}, - {"MutableHashTableOfTensors", {true, {}}}, - {"MutableHashTableOfTensorsV2", {true, {}}}, - {"MutableHashTableV2", {true, {}}}, - {"NcclAllReduce", {true, {}}}, - {"NcclBroadcast", {true, {}}}, - {"Ndtri", {true, {}}}, - {"Neg", {true, {}}}, - {"NextIteration", {true, {}}}, - {"NonMaxSuppression", {true, {}}}, - {"NonMaxSuppressionV2", {true, {}}}, - {"NonMaxSuppressionWithOverlaps", {true, {}}}, - {"NotEqual", {true, {}}}, - {"NthElement", {false, {1}}}, - {"OneHot", {true, {}}}, - {"OnesLike", {true, {}}}, - {"OptionalGetValue", {true, {}}}, - {"Pack", {true, {}}}, - {"ParameterizedTruncatedNormal", {true, {}}}, - {"ParseTensor", {true, {}}}, - {"PlaceholderWithDefault", {true, {}}}, - {"PopulationCount", {true, {}}}, - {"PreventGradient", {true, {}}}, - {"Qr", {true, {}}}, - {"QuantizeAndDequantize", {true, {}}}, - {"QuantizeAndDequantizeV2", {true, {}}}, - {"QuantizeAndDequantizeV3", {true, {}}}, - {"QueueClose", {true, {}}}, - {"QueueDequeue", {true, {}}}, - {"QueueDequeueMany", {true, {}}}, - {"QueueDequeueUpTo", {true, {}}}, - {"QueueSize", {true, {}}}, - {"RaggedRange", {true, {}}}, - {"RandomCrop", {true, {}}}, - {"RandomStandardNormal", {true, {}}}, - {"RandomUniform", {true, {}}}, - {"Range", {true, {}}}, - {"Rank", {true, {}}}, - {"ReadVariableOp", {true, {}}}, - {"ReaderNumRecordsProduced", {true, {}}}, - {"ReaderNumWorkUnitsCompleted", {true, {}}}, - {"ReaderRead", {true, {}}}, - {"ReaderReadUpTo", {true, {}}}, - {"ReaderReset", {true, {}}}, - {"ReaderRestoreState", {true, {}}}, - {"ReaderSerializeState", {true, {}}}, - {"Real", {true, {}}}, - {"Reciprocal", {true, {}}}, - {"ReduceJoin", {true, {}}}, - {"RefEnter", {true, {}}}, - {"RefExit", {true, {}}}, - {"RefIdentity", {true, {}}}, - {"RefNextIteration", {true, {}}}, - {"RegexReplace", {true, {}}}, - {"Relu", {true, {}}}, - {"Relu6", {true, {}}}, - {"Relu6Grad", {false, {0}}}, - {"ReluGrad", {false, {0}}}, - {"Reshape", {false, {1}}}, - {"ResizeBicubic", {false, {1}}}, - {"ResizeBilinear", {false, {1}}}, - {"ResizeNearestNeighbor", {false, {1}}}, - {"Reverse", {false, {0}}}, - {"ReverseSequence", {false, {0}}}, - {"ReverseV2", {false, {0}}}, - {"RightShift", {true, {}}}, - {"Rint", {true, {}}}, - {"Roll", {false, {0}}}, - {"Round", {true, {}}}, - {"Rsqrt", {true, {}}}, - {"SampleDistortedBoundingBox", {true, {}}}, - {"SampleDistortedBoundingBoxV2", {true, {}}}, - {"ScalarSummary", {true, {}}}, - {"ScaleAndTranslate", {false, {1}}}, - {"ScatterAdd", {true, {}}}, - {"ScatterDiv", {true, {}}}, - {"ScatterMul", {true, {}}}, - {"ScatterNd", {false, {1, 2}}}, - {"ScatterNdAdd", {true, {}}}, - {"ScatterNdNonAliasingAdd", {false, {0, 2}}}, - {"ScatterNdSub", {true, {}}}, - {"ScatterNdUpdate", {true, {}}}, - {"ScatterSub", {true, {}}}, - {"SdcaFprint", {true, {}}}, - {"SegmentSum", {false, {0}}}, - {"Select", {false, {2}}}, - {"Selu", {true, {}}}, - {"SerializeTensor", {true, {}}}, - {"SetSize", {true, {}}}, - {"Shape", {true, {}}}, - {"Sigmoid", {true, {}}}, - {"Size", {true, {}}}, - {"Slice", {false, {2}}}, - {"Softmax", {true, {}}}, - {"SoftmaxCrossEntropyWithLogits", {false, {1}}}, - {"SpaceToBatch", {false, {0}}}, - {"SpaceToBatchND", {false, {0}}}, - {"SpaceToDepth", {true, {}}}, - {"SparseAdd", {false, {2, 5, 6}}}, - {"SparseAddGrad", {true, {}}}, - {"SparseDenseCwiseAdd", {true, {}}}, - {"SparseFillEmptyRows", {true, {}}}, - {"SparseMatrixMul", {true, {}}}, - {"SparseMatrixNNZ", {true, {}}}, - {"SparseMatrixSoftmax", {true, {}}}, - {"SparseMatrixTranspose", {true, {}}}, - {"SparseMatrixZeros", {true, {}}}, - {"SparseReduceSum", {false, {1}}}, - {"SparseReorder", {false, {1}}}, - {"SparseSegmentMeanWithNumSegments", {false, {3}}}, - {"SparseSegmentSqrtNWithNumSegments", {false, {3}}}, - {"SparseSegmentSumWithNumSegments", {false, {3}}}, - {"SparseSlice", {false, {2, 4}}}, - {"SparseSoftmax", {false, {1}}}, - {"SparseSoftmaxCrossEntropyWithLogits", {false, {1}}}, - {"SparseSparseMaximum", {true, {}}}, - {"SparseSparseMinimum", {true, {}}}, - {"SparseTensorDenseAdd", {false, {1, 2, 3}}}, - {"SparseToSparseSetOperation", {true, {}}}, - {"Split", {false, {1}}}, - {"Sqrt", {true, {}}}, - {"SqrtGrad", {false, {1}}}, - {"Stack", {true, {}}}, - {"StackClose", {true, {}}}, - {"StackPop", {true, {}}}, - {"StackPush", {true, {}}}, - {"StatelessMultinomial", {true, {}}}, - {"StatelessRandomBinomial", {true, {}}}, - {"StatelessRandomGammaV2", {false, {1}}}, - {"StatelessRandomNormal", {true, {}}}, - {"StatelessRandomPoisson", {true, {}}}, - {"StatelessRandomUniform", {true, {}}}, - {"StatelessRandomUniformFullInt", {true, {}}}, - {"StatelessRandomUniformInt", {true, {}}}, - {"StatelessTruncatedNormal", {true, {}}}, - {"StopGradient", {true, {}}}, - {"StridedSliceGrad", {false, {0, 4}}}, - {"StringSplit", {true, {}}}, - {"StringToHashBucket", {true, {}}}, - {"StringToHashBucketFast", {true, {}}}, - {"StringToHashBucketStrong", {true, {}}}, - {"StringToNumber", {true, {}}}, - {"TFRecordReader", {true, {}}}, - {"Tanh", {true, {}}}, - {"TensorArray", {true, {}}}, - {"TensorArrayClose", {true, {}}}, - {"TensorArrayCloseV2", {true, {}}}, - {"TensorArrayCloseV3", {true, {}}}, - {"TensorArrayGrad", {true, {}}}, - {"TensorArrayGradV2", {true, {}}}, - {"TensorArrayGradV3", {true, {}}}, - {"TensorArrayGradWithShape", {true, {}}}, - {"TensorArrayScatter", {false, {2, 3}}}, - {"TensorArrayScatterV2", {false, {2, 3}}}, - {"TensorArrayScatterV3", {false, {2, 3}}}, - {"TensorArraySize", {true, {}}}, - {"TensorArraySizeV2", {true, {}}}, - {"TensorArraySizeV3", {true, {}}}, - {"TensorArraySplit", {false, {1, 2, 3}}}, - {"TensorArraySplitV2", {false, {1, 2, 3}}}, - {"TensorArraySplitV3", {false, {1, 2, 3}}}, - {"TensorArrayV2", {true, {}}}, - {"TensorArrayV3", {true, {}}}, - {"TensorArrayWrite", {false, {2, 3}}}, - {"TensorArrayWriteV2", {false, {2, 3}}}, - {"TensorArrayWriteV3", {false, {2, 3}}}, - {"TensorListConcatLists", {true, {}}}, - {"TensorListConcatV2", {false, {1, 2}}}, - {"TensorListElementShape", {true, {}}}, - {"TensorListFromTensor", {false, {1}}}, - {"TensorListGetItem", {false, {2}}}, - {"TensorListLength", {true, {}}}, - {"TensorListPopBack", {true, {}}}, - {"TensorListPushBack", {false, {0}}}, - {"TensorListPushBackBatch", {true, {}}}, - {"TensorListScatter", {false, {2}}}, - {"TensorListScatterV2", {false, {2, 3}}}, - {"TensorListStack", {true, {}}}, - {"TensorScatterAdd", {false, {0, 2}}}, - {"TensorScatterSub", {false, {0, 2}}}, - {"TensorScatterUpdate", {false, {0}}}, - {"TensorSummary", {true, {}}}, - {"TensorSummaryV2", {true, {}}}, - {"TextLineReader", {true, {}}}, - {"Timestamp", {true, {}}}, - {"TopKV2", {false, {1}}}, - {"Transpose", {false, {0}}}, - {"TridiagonalSolve", {false, {1}}}, - {"TruncateDiv", {true, {}}}, - {"TruncatedNormal", {true, {}}}, - {"Unpack", {true, {}}}, - {"UnsortedSegmentSum", {false, {0, 2}}}, - {"VarIsInitializedOp", {true, {}}}, - {"VariableShape", {true, {}}}, - {"WholeFileReader", {true, {}}}, - {"XlaClusterOutput", {true, {}}}, - {"XlaSharding", {true, {}}}, - {"ZerosLike", {true, {}}}, - {"VarHandleOp", {true, {}}}, - }); +namespace { +// Keep static data in a format that's easy to init statically. +struct OpIndexInfo { + const char *op_name; + int num_indices; + std::array<int, 4> unused_indices; +}; - auto it = m->find(op_name); +// Helper function to initialize FlatMap<string,FlatSet> from OpIndexInfo. +template <typename T> +auto OpGradientInfoInit(const T &a) { + auto *m = new tensorflow::gtl::FlatMap<string, tensorflow::gtl::FlatSet<int>>; + for (const auto &item : a) { + m->emplace(string(item.op_name), + tensorflow::gtl::FlatSet<int>( + item.unused_indices.begin(), + item.unused_indices.begin() + item.num_indices)); + } + return m; +} +} // namespace - if (it == m->end()) return false; +absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices( + const tensorflow::string &op_name) { + static std::array<OpIndexInfo, 347> a = {{ + {"Acosh"}, + {"AllToAll", 1, {0}}, + {"ApproximateEqual"}, + {"ArgMax"}, + {"ArgMin"}, + {"AsString"}, + {"Asinh"}, + {"Assign"}, + {"AssignAdd"}, + {"AssignSub"}, + {"AudioSummary"}, + {"AudioSummaryV2"}, + {"AvgPool3DGrad", 1, {1}}, + {"AvgPoolGrad", 1, {1}}, + {"BatchNormWithGlobalNormalization", 1, {3}}, + {"BatchToSpace", 1, {0}}, + {"BatchToSpaceND", 1, {0}}, + {"BiasAdd"}, + {"BiasAddV1"}, + {"BitwiseAnd"}, + {"BitwiseOr"}, + {"BitwiseXor"}, + {"BroadcastGradientArgs"}, + {"CSRSparseMatrixToDense"}, + {"CTCBeamSearchDecoder"}, + {"CTCGreedyDecoder"}, + {"CTCLoss"}, + {"CTCLossV2"}, + {"Ceil"}, + {"CheckNumerics"}, + {"CheckNumericsV2"}, + {"Cholesky"}, + {"CollectivePermute", 1, {0}}, + {"Conj"}, + {"ConjugateTranspose", 1, {0}}, + {"Const"}, + {"Conv2DBackpropFilter", 1, {1}}, + {"Conv2DBackpropInput", 1, {0}}, + {"Conv3DBackpropFilterV2", 1, {1}}, + {"Conv3DBackpropInputV2", 1, {0}}, + {"CropAndResize", 1, {3}}, + {"CrossReplicaSum", 1, {0}}, + {"Cumsum", 1, {0}}, + {"DebugGradientIdentity"}, + {"DebugGradientRefIdentity"}, + {"DebugIdentityV2"}, + {"DecodeBase64"}, + {"DecodePaddedRaw"}, + {"DecodeProtoV2"}, + {"DecodeRaw"}, + {"DeleteSessionTensor"}, + {"DenseToCSRSparseMatrix"}, + {"DenseToDenseSetOperation"}, + {"DenseToSparseSetOperation"}, + {"DepthToSpace"}, + {"DepthwiseConv2dNativeBackpropFilter", 1, {1}}, + {"DepthwiseConv2dNativeBackpropInput", 1, {0}}, + {"Diag"}, + {"DiagPart"}, + {"DrawBoundingBoxes"}, + {"EditDistance"}, + {"Elu"}, + {"EncodeBase64"}, + {"EnsureShape"}, + {"Enter"}, + {"Equal"}, + {"Erfinv"}, + {"Exit"}, + {"Exp"}, + {"ExpandDims", 1, {1}}, + {"ExtractGlimpse"}, + {"FFT"}, + {"FFT2D"}, + {"FFT3D"}, + {"Fill"}, + {"FixedLengthRecordReader"}, + {"Floor"}, + {"FloorDiv"}, + {"FusedBatchNorm", 1, {2}}, + {"FusedBatchNormGradV3", 1, {5}}, + {"FusedBatchNormV2", 1, {2}}, + {"FusedBatchNormV3", 1, {2}}, + {"GenerateBoundingBoxProposals"}, + {"GenerateVocabRemapping"}, + {"GetSessionHandle"}, + {"GetSessionHandleV2"}, + {"GetSessionTensor"}, + {"Greater"}, + {"GreaterEqual"}, + {"HSVToRGB"}, + {"HashTable"}, + {"HashTableV2"}, + {"HistogramSummary"}, + {"IFFT"}, + {"IFFT2D"}, + {"IFFT3D"}, + {"Identity"}, + {"IdentityN"}, + {"IdentityReader"}, + {"Imag"}, + {"ImageProjectiveTransformV2", 1, {2}}, + {"ImageSummary"}, + {"InitializeTable"}, + {"InitializeTableFromTextFile"}, + {"InitializeTableFromTextFileV2"}, + {"InitializeTableV2"}, + {"Inv"}, + {"Invert"}, + {"InvertPermutation"}, + {"LMDBReader"}, + {"LeakyReluGrad", 1, {0}}, + {"LeftShift"}, + {"Less"}, + {"LessEqual"}, + {"LinSpace"}, + {"LoadAndRemapMatrix"}, + {"LogSoftmax"}, + {"LogicalAnd"}, + {"LogicalNot"}, + {"LogicalOr"}, + {"LookupTableFind"}, + {"LookupTableFindV2"}, + {"LookupTableInsert"}, + {"LookupTableInsertV2"}, + {"LookupTableSize"}, + {"LookupTableSizeV2"}, + {"LoopCond"}, + {"MatrixBandPart", 1, {0}}, + {"MatrixDiag"}, + {"MatrixDiagPartV2", 1, {2}}, + {"MatrixDiagPartV3", 1, {2}}, + {"MatrixDiagV2", 4, {0, 2, 3, 4}}, + {"MatrixDiagV3", 4, {0, 2, 3, 4}}, + {"MatrixInverse"}, + {"MatrixSetDiagV2", 1, {0}}, + {"MatrixSetDiagV3", 1, {0}}, + {"MatrixSolve", 1, {1}}, + {"MatrixSquareRoot"}, + {"MaxPool3DGrad", 1, {2}}, + {"MaxPool3DGradGrad", 1, {2}}, + {"MaxPoolGrad", 1, {2}}, + {"MaxPoolGradGrad", 1, {2}}, + {"MaxPoolGradV2", 1, {2}}, + {"MirrorPad", 1, {0}}, + {"MirrorPadGrad", 1, {0}}, + {"Multinomial"}, + {"MutableDenseHashTable"}, + {"MutableDenseHashTableV2"}, + {"MutableHashTable"}, + {"MutableHashTableOfTensors"}, + {"MutableHashTableOfTensorsV2"}, + {"MutableHashTableV2"}, + {"NcclAllReduce"}, + {"NcclBroadcast"}, + {"Ndtri"}, + {"Neg"}, + {"NextIteration"}, + {"NonMaxSuppression"}, + {"NonMaxSuppressionV2"}, + {"NonMaxSuppressionWithOverlaps"}, + {"NotEqual"}, + {"NthElement", 1, {1}}, + {"OneHot"}, + {"OnesLike"}, + {"OptionalGetValue"}, + {"Pack"}, + {"ParameterizedTruncatedNormal"}, + {"ParseTensor"}, + {"PlaceholderWithDefault"}, + {"PopulationCount"}, + {"PreventGradient"}, + {"Qr"}, + {"QuantizeAndDequantize"}, + {"QuantizeAndDequantizeV2"}, + {"QuantizeAndDequantizeV3"}, + {"QueueClose"}, + {"QueueDequeue"}, + {"QueueDequeueMany"}, + {"QueueDequeueUpTo"}, + {"QueueSize"}, + {"RaggedRange"}, + {"RandomCrop"}, + {"RandomStandardNormal"}, + {"RandomUniform"}, + {"Range"}, + {"Rank"}, + {"ReadVariableOp"}, + {"ReaderNumRecordsProduced"}, + {"ReaderNumWorkUnitsCompleted"}, + {"ReaderRead"}, + {"ReaderReadUpTo"}, + {"ReaderReset"}, + {"ReaderRestoreState"}, + {"ReaderSerializeState"}, + {"Real"}, + {"Reciprocal"}, + {"ReduceJoin"}, + {"RefEnter"}, + {"RefExit"}, + {"RefIdentity"}, + {"RefNextIteration"}, + {"RegexReplace"}, + {"Relu"}, + {"Relu6"}, + {"Relu6Grad", 1, {0}}, + {"ReluGrad", 1, {0}}, + {"Reshape", 1, {1}}, + {"ResizeBicubic", 1, {1}}, + {"ResizeBilinear", 1, {1}}, + {"ResizeNearestNeighbor", 1, {1}}, + {"Reverse", 1, {0}}, + {"ReverseSequence", 1, {0}}, + {"ReverseV2", 1, {0}}, + {"RightShift"}, + {"Rint"}, + {"Roll", 1, {0}}, + {"Round"}, + {"Rsqrt"}, + {"SampleDistortedBoundingBox"}, + {"SampleDistortedBoundingBoxV2"}, + {"ScalarSummary"}, + {"ScaleAndTranslate", 1, {1}}, + {"ScatterAdd"}, + {"ScatterDiv"}, + {"ScatterMul"}, + {"ScatterNd", 2, {1, 2}}, + {"ScatterNdAdd"}, + {"ScatterNdNonAliasingAdd", 2, {0, 2}}, + {"ScatterNdSub"}, + {"ScatterNdUpdate"}, + {"ScatterSub"}, + {"SdcaFprint"}, + {"SegmentSum", 1, {0}}, + {"Select", 1, {2}}, + {"Selu"}, + {"SerializeTensor"}, + {"SetSize"}, + {"Shape"}, + {"Sigmoid"}, + {"Size"}, + {"Slice", 1, {2}}, + {"Softmax"}, + {"SoftmaxCrossEntropyWithLogits", 1, {1}}, + {"SpaceToBatch", 1, {0}}, + {"SpaceToBatchND", 1, {0}}, + {"SpaceToDepth"}, + {"SparseAdd", 3, {2, 5, 6}}, + {"SparseAddGrad"}, + {"SparseDenseCwiseAdd"}, + {"SparseFillEmptyRows"}, + {"SparseMatrixMul"}, + {"SparseMatrixNNZ"}, + {"SparseMatrixSoftmax"}, + {"SparseMatrixTranspose"}, + {"SparseMatrixZeros"}, + {"SparseReduceSum", 1, {1}}, + {"SparseReorder", 1, {1}}, + {"SparseSegmentMeanWithNumSegments", 1, {3}}, + {"SparseSegmentSqrtNWithNumSegments", 1, {3}}, + {"SparseSegmentSumWithNumSegments", 1, {3}}, + {"SparseSlice", 2, {2, 4}}, + {"SparseSoftmax", 1, {1}}, + {"SparseSoftmaxCrossEntropyWithLogits", 1, {1}}, + {"SparseSparseMaximum"}, + {"SparseSparseMinimum"}, + {"SparseTensorDenseAdd", 3, {1, 2, 3}}, + {"SparseToSparseSetOperation"}, + {"Split", 1, {1}}, + {"Sqrt"}, + {"SqrtGrad", 1, {1}}, + {"Stack"}, + {"StackClose"}, + {"StackPop"}, + {"StackPush"}, + {"StatelessMultinomial"}, + {"StatelessRandomBinomial"}, + {"StatelessRandomGammaV2", 1, {1}}, + {"StatelessRandomNormal"}, + {"StatelessRandomPoisson"}, + {"StatelessRandomUniform"}, + {"StatelessRandomUniformFullInt"}, + {"StatelessRandomUniformInt"}, + {"StatelessTruncatedNormal"}, + {"StopGradient"}, + {"StridedSliceGrad", 2, {0, 4}}, + {"StringSplit"}, + {"StringToHashBucket"}, + {"StringToHashBucketFast"}, + {"StringToHashBucketStrong"}, + {"StringToNumber"}, + {"TFRecordReader"}, + {"Tanh"}, + {"TensorArray"}, + {"TensorArrayClose"}, + {"TensorArrayCloseV2"}, + {"TensorArrayCloseV3"}, + {"TensorArrayGrad"}, + {"TensorArrayGradV2"}, + {"TensorArrayGradV3"}, + {"TensorArrayGradWithShape"}, + {"TensorArrayScatter", 2, {2, 3}}, + {"TensorArrayScatterV2", 2, {2, 3}}, + {"TensorArrayScatterV3", 2, {2, 3}}, + {"TensorArraySize"}, + {"TensorArraySizeV2"}, + {"TensorArraySizeV3"}, + {"TensorArraySplit", 3, {1, 2, 3}}, + {"TensorArraySplitV2", 3, {1, 2, 3}}, + {"TensorArraySplitV3", 3, {1, 2, 3}}, + {"TensorArrayV2"}, + {"TensorArrayV3"}, + {"TensorArrayWrite", 2, {2, 3}}, + {"TensorArrayWriteV2", 2, {2, 3}}, + {"TensorArrayWriteV3", 2, {2, 3}}, + {"TensorListConcatLists"}, + {"TensorListConcatV2", 2, {1, 2}}, + {"TensorListElementShape"}, + {"TensorListFromTensor", 1, {1}}, + {"TensorListGetItem", 1, {2}}, + {"TensorListLength"}, + {"TensorListPopBack"}, + {"TensorListPushBack", 1, {0}}, + {"TensorListPushBackBatch"}, + {"TensorListScatter", 1, {2}}, + {"TensorListScatterV2", 2, {2, 3}}, + {"TensorListStack"}, + {"TensorScatterAdd", 2, {0, 2}}, + {"TensorScatterSub", 2, {0, 2}}, + {"TensorScatterUpdate", 1, {0}}, + {"TensorSummary"}, + {"TensorSummaryV2"}, + {"TextLineReader"}, + {"Timestamp"}, + {"TopKV2", 1, {1}}, + {"Transpose", 1, {0}}, + {"TridiagonalSolve", 1, {1}}, + {"TruncateDiv"}, + {"TruncatedNormal"}, + {"Unpack"}, + {"UnsortedSegmentSum", 2, {0, 2}}, + {"VarIsInitializedOp"}, + {"VariableShape"}, + {"WholeFileReader"}, + {"XlaClusterOutput"}, + {"XlaSharding"}, + {"ZerosLike"}, + {"VarHandleOp"}, + }}; + static const auto &m = *OpGradientInfoInit(a); - *output = &it->second; - return true; + auto it = m.find(op_name); + if (it != m.end()) { + return it->second; + } + return absl::nullopt; } -bool OpGradientDoesntRequireOutputIndices( - const string& op_name, - std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) { - static tensorflow::gtl::FlatMap< - string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m = - new tensorflow::gtl::FlatMap< - string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({ - {"Abs", {true, {}}}, - {"AccumulateNV2", {true, {}}}, - {"Acos", {true, {}}}, - {"Add", {true, {}}}, - {"AddN", {true, {}}}, - {"AddV2", {true, {}}}, - {"AllToAll", {true, {}}}, - {"Angle", {true, {}}}, - {"ApproximateEqual", {true, {}}}, - {"ArgMax", {true, {}}}, - {"ArgMin", {true, {}}}, - {"AsString", {true, {}}}, - {"Asin", {true, {}}}, - {"Assert", {true, {}}}, - {"Assign", {true, {}}}, - {"AssignAdd", {true, {}}}, - {"AssignSub", {true, {}}}, - {"Atan", {true, {}}}, - {"Atan2", {true, {}}}, - {"Atanh", {true, {}}}, - {"AudioSummary", {true, {}}}, - {"AudioSummaryV2", {true, {}}}, - {"AvgPool", {true, {}}}, - {"AvgPool3D", {true, {}}}, - {"AvgPool3DGrad", {true, {}}}, - {"AvgPoolGrad", {true, {}}}, - {"BatchMatMul", {true, {}}}, - {"BatchMatMulV2", {true, {}}}, - {"BatchNormWithGlobalNormalization", {true, {}}}, - {"BatchToSpace", {true, {}}}, - {"BatchToSpaceND", {true, {}}}, - {"Betainc", {true, {}}}, - {"BiasAdd", {true, {}}}, - {"BiasAddGrad", {true, {}}}, - {"BiasAddV1", {true, {}}}, - {"BitwiseAnd", {true, {}}}, - {"BitwiseOr", {true, {}}}, - {"BitwiseXor", {true, {}}}, - {"BroadcastGradientArgs", {true, {}}}, - {"BroadcastTo", {true, {}}}, - {"CSRSparseMatrixToDense", {true, {}}}, - {"CTCGreedyDecoder", {true, {}}}, - {"CTCLoss", {false, {0}}}, - {"CTCLossV2", {false, {0}}}, - {"Cast", {true, {}}}, - {"Ceil", {true, {}}}, - {"CheckNumerics", {true, {}}}, - {"CheckNumericsV2", {true, {}}}, - {"CollectivePermute", {true, {}}}, - {"Complex", {true, {}}}, - {"Concat", {true, {}}}, - {"ConcatV2", {true, {}}}, - {"Conj", {true, {}}}, - {"ConjugateTranspose", {true, {}}}, - {"Const", {true, {}}}, - {"Conv2D", {true, {}}}, - {"Conv2DBackpropFilter", {true, {}}}, - {"Conv2DBackpropInput", {true, {}}}, - {"Conv3D", {true, {}}}, - {"Conv3DBackpropFilterV2", {true, {}}}, - {"Conv3DBackpropInputV2", {true, {}}}, - {"Cos", {true, {}}}, - {"Cosh", {true, {}}}, - {"CropAndResize", {true, {}}}, - {"Cross", {true, {}}}, - {"CrossReplicaSum", {true, {}}}, - {"Cumprod", {true, {}}}, - {"Cumsum", {true, {}}}, - {"DebugGradientIdentity", {true, {}}}, - {"DebugGradientRefIdentity", {true, {}}}, - {"DebugIdentityV2", {true, {}}}, - {"DecodeBase64", {true, {}}}, - {"DecodePaddedRaw", {true, {}}}, - {"DecodeRaw", {true, {}}}, - {"DeleteSessionTensor", {true, {}}}, - {"DenseToCSRSparseMatrix", {true, {}}}, - {"DenseToDenseSetOperation", {true, {}}}, - {"DenseToSparseSetOperation", {true, {}}}, - {"DepthToSpace", {true, {}}}, - {"DepthwiseConv2dNative", {true, {}}}, - {"DepthwiseConv2dNativeBackpropFilter", {true, {}}}, - {"DepthwiseConv2dNativeBackpropInput", {true, {}}}, - {"Diag", {true, {}}}, - {"DiagPart", {true, {}}}, - {"Digamma", {true, {}}}, - {"Dilation2D", {true, {}}}, - {"Div", {true, {}}}, - {"DivNoNan", {true, {}}}, - {"DrawBoundingBoxes", {true, {}}}, - {"DynamicPartition", {true, {}}}, - {"EditDistance", {true, {}}}, - {"Einsum", {true, {}}}, - {"EluGrad", {true, {}}}, - {"EncodeBase64", {true, {}}}, - {"EncodeProto", {true, {}}}, - {"EnsureShape", {true, {}}}, - {"Enter", {true, {}}}, - {"Equal", {true, {}}}, - {"Erf", {true, {}}}, - {"Erfc", {true, {}}}, - {"Exit", {true, {}}}, - {"ExpandDims", {true, {}}}, - {"Expint", {true, {}}}, - {"Expm1", {true, {}}}, - {"ExtractGlimpse", {true, {}}}, - {"FFT", {true, {}}}, - {"FFT2D", {true, {}}}, - {"FFT3D", {true, {}}}, - {"FakeQuantWithMinMaxArgs", {true, {}}}, - {"FakeQuantWithMinMaxVars", {true, {}}}, - {"FakeQuantWithMinMaxVarsPerChannel", {true, {}}}, - {"Fill", {true, {}}}, - {"FixedLengthRecordReader", {true, {}}}, - {"Floor", {true, {}}}, - {"FloorDiv", {true, {}}}, - {"FloorMod", {true, {}}}, - {"FractionalAvgPool", {false, {0}}}, - {"FresnelCos", {true, {}}}, - {"FresnelSin", {true, {}}}, - {"FusedBatchNorm", {false, {0, 1, 2}}}, - {"FusedBatchNormGrad", {true, {}}}, - {"FusedBatchNormGradV2", {true, {}}}, - {"FusedBatchNormGradV3", {true, {}}}, - {"FusedBatchNormV2", {false, {0, 1, 2}}}, - {"FusedBatchNormV3", {false, {0, 1, 2}}}, - {"Gather", {true, {}}}, - {"GatherNd", {true, {}}}, - {"GatherV2", {true, {}}}, - {"GenerateBoundingBoxProposals", {true, {}}}, - {"GenerateVocabRemapping", {true, {}}}, - {"GetSessionHandle", {true, {}}}, - {"GetSessionHandleV2", {true, {}}}, - {"GetSessionTensor", {true, {}}}, - {"Greater", {true, {}}}, - {"GreaterEqual", {true, {}}}, - {"HSVToRGB", {true, {}}}, - {"HashTable", {true, {}}}, - {"HashTableV2", {true, {}}}, - {"HistogramSummary", {true, {}}}, - {"IFFT", {true, {}}}, - {"IFFT2D", {true, {}}}, - {"IFFT3D", {true, {}}}, - {"IRFFT", {true, {}}}, - {"IRFFT2D", {true, {}}}, - {"Identity", {true, {}}}, - {"IdentityN", {true, {}}}, - {"IdentityReader", {true, {}}}, - {"Igamma", {true, {}}}, - {"Igammac", {true, {}}}, - {"Imag", {true, {}}}, - {"ImageProjectiveTransformV2", {true, {}}}, - {"ImageSummary", {true, {}}}, - {"InitializeTable", {true, {}}}, - {"InitializeTableFromTextFile", {true, {}}}, - {"InitializeTableFromTextFileV2", {true, {}}}, - {"InitializeTableV2", {true, {}}}, - {"InvGrad", {true, {}}}, - {"Invert", {true, {}}}, - {"InvertPermutation", {true, {}}}, - {"L2Loss", {true, {}}}, - {"LMDBReader", {true, {}}}, - {"LeakyRelu", {true, {}}}, - {"LeakyReluGrad", {true, {}}}, - {"LeftShift", {true, {}}}, - {"Less", {true, {}}}, - {"LessEqual", {true, {}}}, - {"Lgamma", {true, {}}}, - {"LinSpace", {true, {}}}, - {"LoadAndRemapMatrix", {true, {}}}, - {"Log", {true, {}}}, - {"Log1p", {true, {}}}, - {"LogMatrixDeterminant", {false, {0}}}, - {"LogicalAnd", {true, {}}}, - {"LogicalNot", {true, {}}}, - {"LogicalOr", {true, {}}}, - {"LookupTableFind", {true, {}}}, - {"LookupTableFindV2", {true, {}}}, - {"LookupTableInsert", {true, {}}}, - {"LookupTableInsertV2", {true, {}}}, - {"LookupTableSize", {true, {}}}, - {"LookupTableSizeV2", {true, {}}}, - {"LoopCond", {true, {}}}, - {"MatMul", {true, {}}}, - {"MatrixBandPart", {true, {}}}, - {"MatrixDiag", {true, {}}}, - {"MatrixDiagPart", {true, {}}}, - {"MatrixDiagPartV2", {true, {}}}, - {"MatrixDiagPartV3", {true, {}}}, - {"MatrixDiagV2", {true, {}}}, - {"MatrixDiagV3", {true, {}}}, - {"MatrixSetDiag", {true, {}}}, - {"MatrixSetDiagV2", {true, {}}}, - {"MatrixSetDiagV3", {true, {}}}, - {"MaxPool3DGrad", {true, {}}}, - {"MaxPool3DGradGrad", {true, {}}}, - {"MaxPoolGrad", {true, {}}}, - {"MaxPoolGradGrad", {true, {}}}, - {"MaxPoolGradV2", {true, {}}}, - {"MaxPoolWithArgmax", {false, {0}}}, - {"Maximum", {true, {}}}, - {"Merge", {false, {0}}}, - {"MergeSummary", {true, {}}}, - {"Minimum", {true, {}}}, - {"MirrorPad", {true, {}}}, - {"MirrorPadGrad", {true, {}}}, - {"Mul", {true, {}}}, - {"MulNoNan", {true, {}}}, - {"Multinomial", {true, {}}}, - {"MutableDenseHashTable", {true, {}}}, - {"MutableDenseHashTableV2", {true, {}}}, - {"MutableHashTable", {true, {}}}, - {"MutableHashTableOfTensors", {true, {}}}, - {"MutableHashTableOfTensorsV2", {true, {}}}, - {"MutableHashTableV2", {true, {}}}, - {"NcclAllReduce", {true, {}}}, - {"NcclBroadcast", {true, {}}}, - {"NcclReduce", {true, {}}}, - {"Neg", {true, {}}}, - {"NextAfter", {true, {}}}, - {"NextIteration", {true, {}}}, - {"NonMaxSuppression", {true, {}}}, - {"NonMaxSuppressionV2", {true, {}}}, - {"NonMaxSuppressionWithOverlaps", {true, {}}}, - {"NotEqual", {true, {}}}, - {"OneHot", {true, {}}}, - {"OnesLike", {true, {}}}, - {"OptionalFromValue", {true, {}}}, - {"OptionalGetValue", {true, {}}}, - {"Pack", {true, {}}}, - {"Pad", {true, {}}}, - {"PadV2", {true, {}}}, - {"ParameterizedTruncatedNormal", {true, {}}}, - {"ParseTensor", {true, {}}}, - {"PlaceholderWithDefault", {true, {}}}, - {"Polygamma", {true, {}}}, - {"PopulationCount", {true, {}}}, - {"PreventGradient", {true, {}}}, - {"Print", {true, {}}}, - {"Prod", {true, {}}}, - {"QuantizeAndDequantize", {true, {}}}, - {"QuantizeAndDequantizeV2", {true, {}}}, - {"QuantizeAndDequantizeV3", {true, {}}}, - {"QueueClose", {true, {}}}, - {"QueueEnqueue", {true, {}}}, - {"QueueEnqueueMany", {true, {}}}, - {"QueueSize", {true, {}}}, - {"RFFT", {true, {}}}, - {"RFFT2D", {true, {}}}, - {"RaggedGather", {true, {}}}, - {"RaggedRange", {true, {}}}, - {"RaggedTensorToSparse", {true, {}}}, - {"RaggedTensorToTensor", {true, {}}}, - {"RaggedTensorToVariant", {true, {}}}, - {"RandomCrop", {true, {}}}, - {"RandomStandardNormal", {true, {}}}, - {"RandomUniform", {true, {}}}, - {"Range", {true, {}}}, - {"Rank", {true, {}}}, - {"ReadVariableOp", {true, {}}}, - {"ReaderNumRecordsProduced", {true, {}}}, - {"ReaderNumWorkUnitsCompleted", {true, {}}}, - {"ReaderRead", {true, {}}}, - {"ReaderReadUpTo", {true, {}}}, - {"ReaderReset", {true, {}}}, - {"ReaderRestoreState", {true, {}}}, - {"ReaderSerializeState", {true, {}}}, - {"Real", {true, {}}}, - {"RealDiv", {true, {}}}, - {"ReciprocalGrad", {true, {}}}, - {"ReduceJoin", {true, {}}}, - {"RefEnter", {true, {}}}, - {"RefExit", {true, {}}}, - {"RefIdentity", {true, {}}}, - {"RefMerge", {false, {0}}}, - {"RefNextIteration", {true, {}}}, - {"RefSwitch", {true, {}}}, - {"RegexReplace", {true, {}}}, - {"Relu6Grad", {true, {}}}, - {"ReluGrad", {true, {}}}, - {"Reshape", {true, {}}}, - {"ResizeBicubic", {true, {}}}, - {"ResizeBilinear", {true, {}}}, - {"ResizeNearestNeighbor", {true, {}}}, - {"ResourceGather", {true, {}}}, - {"ResourceGatherNd", {true, {}}}, - {"Reverse", {true, {}}}, - {"ReverseSequence", {true, {}}}, - {"ReverseV2", {true, {}}}, - {"RightShift", {true, {}}}, - {"Rint", {true, {}}}, - {"Roll", {true, {}}}, - {"Round", {true, {}}}, - {"RsqrtGrad", {true, {}}}, - {"SampleDistortedBoundingBox", {true, {}}}, - {"SampleDistortedBoundingBoxV2", {true, {}}}, - {"ScalarSummary", {true, {}}}, - {"ScaleAndTranslate", {true, {}}}, - {"ScatterAdd", {true, {}}}, - {"ScatterDiv", {true, {}}}, - {"ScatterMul", {true, {}}}, - {"ScatterNd", {true, {}}}, - {"ScatterNdAdd", {true, {}}}, - {"ScatterNdNonAliasingAdd", {true, {}}}, - {"ScatterNdSub", {true, {}}}, - {"ScatterNdUpdate", {true, {}}}, - {"ScatterSub", {true, {}}}, - {"SdcaFprint", {true, {}}}, - {"SdcaShrinkL1", {true, {}}}, - {"SegmentMean", {true, {}}}, - {"SegmentSum", {true, {}}}, - {"Select", {true, {}}}, - {"SeluGrad", {true, {}}}, - {"SerializeTensor", {true, {}}}, - {"SetSize", {true, {}}}, - {"Shape", {true, {}}}, - {"SigmoidGrad", {true, {}}}, - {"Sign", {true, {}}}, - {"Sin", {true, {}}}, - {"Sinh", {true, {}}}, - {"Size", {true, {}}}, - {"SoftmaxCrossEntropyWithLogits", {false, {0}}}, - {"Softplus", {true, {}}}, - {"SoftplusGrad", {true, {}}}, - {"Softsign", {true, {}}}, - {"SpaceToBatch", {true, {}}}, - {"SpaceToBatchND", {true, {}}}, - {"SpaceToDepth", {true, {}}}, - {"SparseAdd", {false, {1, 2}}}, - {"SparseAddGrad", {true, {}}}, - {"SparseConcat", {true, {}}}, - {"SparseDenseCwiseAdd", {true, {}}}, - {"SparseDenseCwiseDiv", {true, {}}}, - {"SparseDenseCwiseMul", {true, {}}}, - {"SparseFillEmptyRows", {false, {0, 1, 2}}}, - {"SparseMatMul", {true, {}}}, - {"SparseMatrixAdd", {true, {}}}, - {"SparseMatrixMatMul", {true, {}}}, - {"SparseMatrixMul", {true, {}}}, - {"SparseMatrixNNZ", {true, {}}}, - {"SparseMatrixSparseMatMul", {true, {}}}, - {"SparseMatrixTranspose", {true, {}}}, - {"SparseMatrixZeros", {true, {}}}, - {"SparseReduceSum", {true, {}}}, - {"SparseReorder", {true, {}}}, - {"SparseSegmentMean", {true, {}}}, - {"SparseSegmentMeanWithNumSegments", {true, {}}}, - {"SparseSegmentSqrtN", {true, {}}}, - {"SparseSegmentSqrtNWithNumSegments", {true, {}}}, - {"SparseSegmentSum", {true, {}}}, - {"SparseSegmentSumWithNumSegments", {true, {}}}, - {"SparseSlice", {false, {1, 2}}}, - {"SparseSoftmaxCrossEntropyWithLogits", {false, {0}}}, - {"SparseSparseMaximum", {true, {}}}, - {"SparseSparseMinimum", {true, {}}}, - {"SparseTensorDenseAdd", {true, {}}}, - {"SparseTensorDenseMatMul", {true, {}}}, - {"SparseToDense", {true, {}}}, - {"SparseToSparseSetOperation", {true, {}}}, - {"Spence", {true, {}}}, - {"Split", {true, {}}}, - {"SplitV", {true, {}}}, - {"Square", {true, {}}}, - {"SquaredDifference", {true, {}}}, - {"Squeeze", {true, {}}}, - {"Stack", {true, {}}}, - {"StackClose", {true, {}}}, - {"StackPop", {true, {}}}, - {"StackPush", {true, {}}}, - {"StatelessMultinomial", {true, {}}}, - {"StatelessRandomBinomial", {true, {}}}, - {"StatelessRandomNormal", {true, {}}}, - {"StatelessRandomPoisson", {true, {}}}, - {"StatelessRandomUniform", {true, {}}}, - {"StatelessRandomUniformFullInt", {true, {}}}, - {"StatelessRandomUniformInt", {true, {}}}, - {"StatelessTruncatedNormal", {true, {}}}, - {"StopGradient", {true, {}}}, - {"StridedSlice", {true, {}}}, - {"StridedSliceGrad", {true, {}}}, - {"StringJoin", {true, {}}}, - {"StringSplit", {true, {}}}, - {"StringToHashBucket", {true, {}}}, - {"StringToHashBucketFast", {true, {}}}, - {"StringToHashBucketStrong", {true, {}}}, - {"StringToNumber", {true, {}}}, - {"Sub", {true, {}}}, - {"Sum", {true, {}}}, - {"Switch", {true, {}}}, - {"TFRecordReader", {true, {}}}, - {"TPUEmbeddingActivations", {true, {}}}, - {"TPUReplicatedInput", {true, {}}}, - {"Tan", {true, {}}}, - {"TanhGrad", {true, {}}}, - {"TensorArray", {true, {}}}, - {"TensorArrayClose", {true, {}}}, - {"TensorArrayCloseV2", {true, {}}}, - {"TensorArrayCloseV3", {true, {}}}, - {"TensorArrayConcat", {false, {0}}}, - {"TensorArrayConcatV2", {false, {0}}}, - {"TensorArrayConcatV3", {false, {0}}}, - {"TensorArrayGather", {true, {}}}, - {"TensorArrayGatherV2", {true, {}}}, - {"TensorArrayGatherV3", {true, {}}}, - {"TensorArrayGrad", {true, {}}}, - {"TensorArrayGradV2", {true, {}}}, - {"TensorArrayGradV3", {true, {}}}, - {"TensorArrayGradWithShape", {true, {}}}, - {"TensorArrayRead", {true, {}}}, - {"TensorArrayReadV2", {true, {}}}, - {"TensorArrayReadV3", {true, {}}}, - {"TensorArrayScatter", {true, {}}}, - {"TensorArrayScatterV2", {true, {}}}, - {"TensorArrayScatterV3", {true, {}}}, - {"TensorArraySize", {true, {}}}, - {"TensorArraySizeV2", {true, {}}}, - {"TensorArraySizeV3", {true, {}}}, - {"TensorArraySplit", {true, {}}}, - {"TensorArraySplitV2", {true, {}}}, - {"TensorArraySplitV3", {true, {}}}, - {"TensorArrayV2", {true, {}}}, - {"TensorArrayV3", {true, {}}}, - {"TensorArrayWrite", {true, {}}}, - {"TensorArrayWriteV2", {true, {}}}, - {"TensorArrayWriteV3", {true, {}}}, - {"TensorListConcat", {false, {0}}}, - {"TensorListConcatLists", {true, {}}}, - {"TensorListConcatV2", {false, {0}}}, - {"TensorListElementShape", {true, {}}}, - {"TensorListGather", {true, {}}}, - {"TensorListGetItem", {true, {}}}, - {"TensorListLength", {true, {}}}, - {"TensorListPopBack", {false, {1}}}, - {"TensorListPushBack", {true, {}}}, - {"TensorListPushBackBatch", {true, {}}}, - {"TensorListResize", {true, {}}}, - {"TensorListScatter", {true, {}}}, - {"TensorListScatterIntoExistingList", {true, {}}}, - {"TensorListScatterV2", {true, {}}}, - {"TensorListSetItem", {true, {}}}, - {"TensorListSplit", {true, {}}}, - {"TensorListStack", {true, {}}}, - {"TensorScatterAdd", {true, {}}}, - {"TensorScatterSub", {true, {}}}, - {"TensorScatterUpdate", {true, {}}}, - {"TensorSummary", {true, {}}}, - {"TensorSummaryV2", {true, {}}}, - {"TextLineReader", {true, {}}}, - {"Tile", {true, {}}}, - {"Timestamp", {true, {}}}, - {"TopK", {false, {0}}}, - {"TopKV2", {false, {0}}}, - {"Transpose", {true, {}}}, - {"TridiagonalMatMul", {true, {}}}, - {"TruncateDiv", {true, {}}}, - {"TruncatedNormal", {true, {}}}, - {"Unpack", {true, {}}}, - {"UnsortedSegmentSum", {true, {}}}, - {"VarIsInitializedOp", {true, {}}}, - {"VariableShape", {true, {}}}, - {"WholeFileReader", {true, {}}}, - {"Xdivy", {true, {}}}, - {"XlaClusterOutput", {true, {}}}, - {"XlaEinsum", {true, {}}}, - {"XlaSharding", {true, {}}}, - {"Xlog1py", {true, {}}}, - {"Xlogy", {true, {}}}, - {"ZerosLike", {true, {}}}, - {"Zeta", {true, {}}}, - {"VarHandleOp", {true, {}}}, - }); +absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices( + const tensorflow::string &op_name) { + static std::array<OpIndexInfo, 469> a = {{ + {"Abs"}, + {"AccumulateNV2"}, + {"Acos"}, + {"Add"}, + {"AddN"}, + {"AddV2"}, + {"AllToAll"}, + {"Angle"}, + {"ApproximateEqual"}, + {"ArgMax"}, + {"ArgMin"}, + {"AsString"}, + {"Asin"}, + {"Assert"}, + {"Assign"}, + {"AssignAdd"}, + {"AssignSub"}, + {"Atan"}, + {"Atan2"}, + {"Atanh"}, + {"AudioSummary"}, + {"AudioSummaryV2"}, + {"AvgPool"}, + {"AvgPool3D"}, + {"AvgPool3DGrad"}, + {"AvgPoolGrad"}, + {"BatchMatMul"}, + {"BatchMatMulV2"}, + {"BatchNormWithGlobalNormalization"}, + {"BatchToSpace"}, + {"BatchToSpaceND"}, + {"Betainc"}, + {"BiasAdd"}, + {"BiasAddGrad"}, + {"BiasAddV1"}, + {"BitwiseAnd"}, + {"BitwiseOr"}, + {"BitwiseXor"}, + {"BroadcastGradientArgs"}, + {"BroadcastTo"}, + {"CSRSparseMatrixToDense"}, + {"CTCGreedyDecoder"}, + {"CTCLoss", 1, {0}}, + {"CTCLossV2", 1, {0}}, + {"Cast"}, + {"Ceil"}, + {"CheckNumerics"}, + {"CheckNumericsV2"}, + {"CollectivePermute"}, + {"Complex"}, + {"Concat"}, + {"ConcatV2"}, + {"Conj"}, + {"ConjugateTranspose"}, + {"Const"}, + {"Conv2D"}, + {"Conv2DBackpropFilter"}, + {"Conv2DBackpropInput"}, + {"Conv3D"}, + {"Conv3DBackpropFilterV2"}, + {"Conv3DBackpropInputV2"}, + {"Cos"}, + {"Cosh"}, + {"CropAndResize"}, + {"Cross"}, + {"CrossReplicaSum"}, + {"Cumprod"}, + {"Cumsum"}, + {"DebugGradientIdentity"}, + {"DebugGradientRefIdentity"}, + {"DebugIdentityV2"}, + {"DecodeBase64"}, + {"DecodePaddedRaw"}, + {"DecodeRaw"}, + {"DeleteSessionTensor"}, + {"DenseToCSRSparseMatrix"}, + {"DenseToDenseSetOperation"}, + {"DenseToSparseSetOperation"}, + {"DepthToSpace"}, + {"DepthwiseConv2dNative"}, + {"DepthwiseConv2dNativeBackpropFilter"}, + {"DepthwiseConv2dNativeBackpropInput"}, + {"Diag"}, + {"DiagPart"}, + {"Digamma"}, + {"Dilation2D"}, + {"Div"}, + {"DivNoNan"}, + {"DrawBoundingBoxes"}, + {"DynamicPartition"}, + {"EditDistance"}, + {"Einsum"}, + {"EluGrad"}, + {"EncodeBase64"}, + {"EncodeProto"}, + {"EnsureShape"}, + {"Enter"}, + {"Equal"}, + {"Erf"}, + {"Erfc"}, + {"Exit"}, + {"ExpandDims"}, + {"Expint"}, + {"Expm1"}, + {"ExtractGlimpse"}, + {"FFT"}, + {"FFT2D"}, + {"FFT3D"}, + {"FakeQuantWithMinMaxArgs"}, + {"FakeQuantWithMinMaxVars"}, + {"FakeQuantWithMinMaxVarsPerChannel"}, + {"Fill"}, + {"FixedLengthRecordReader"}, + {"Floor"}, + {"FloorDiv"}, + {"FloorMod"}, + {"FractionalAvgPool", 1, {0}}, + {"FresnelCos"}, + {"FresnelSin"}, + {"FusedBatchNorm", 3, {0, 1, 2}}, + {"FusedBatchNormGrad"}, + {"FusedBatchNormGradV2"}, + {"FusedBatchNormGradV3"}, + {"FusedBatchNormV2", 3, {0, 1, 2}}, + {"FusedBatchNormV3", 3, {0, 1, 2}}, + {"Gather"}, + {"GatherNd"}, + {"GatherV2"}, + {"GenerateBoundingBoxProposals"}, + {"GenerateVocabRemapping"}, + {"GetSessionHandle"}, + {"GetSessionHandleV2"}, + {"GetSessionTensor"}, + {"Greater"}, + {"GreaterEqual"}, + {"HSVToRGB"}, + {"HashTable"}, + {"HashTableV2"}, + {"HistogramSummary"}, + {"IFFT"}, + {"IFFT2D"}, + {"IFFT3D"}, + {"IRFFT"}, + {"IRFFT2D"}, + {"Identity"}, + {"IdentityN"}, + {"IdentityReader"}, + {"Igamma"}, + {"Igammac"}, + {"Imag"}, + {"ImageProjectiveTransformV2"}, + {"ImageSummary"}, + {"InitializeTable"}, + {"InitializeTableFromTextFile"}, + {"InitializeTableFromTextFileV2"}, + {"InitializeTableV2"}, + {"InvGrad"}, + {"Invert"}, + {"InvertPermutation"}, + {"L2Loss"}, + {"LMDBReader"}, + {"LeakyRelu"}, + {"LeakyReluGrad"}, + {"LeftShift"}, + {"Less"}, + {"LessEqual"}, + {"Lgamma"}, + {"LinSpace"}, + {"LoadAndRemapMatrix"}, + {"Log"}, + {"Log1p"}, + {"LogMatrixDeterminant", 1, {0}}, + {"LogicalAnd"}, + {"LogicalNot"}, + {"LogicalOr"}, + {"LookupTableFind"}, + {"LookupTableFindV2"}, + {"LookupTableInsert"}, + {"LookupTableInsertV2"}, + {"LookupTableSize"}, + {"LookupTableSizeV2"}, + {"LoopCond"}, + {"MatMul"}, + {"MatrixBandPart"}, + {"MatrixDiag"}, + {"MatrixDiagPart"}, + {"MatrixDiagPartV2"}, + {"MatrixDiagPartV3"}, + {"MatrixDiagV2"}, + {"MatrixDiagV3"}, + {"MatrixSetDiag"}, + {"MatrixSetDiagV2"}, + {"MatrixSetDiagV3"}, + {"MaxPool3DGrad"}, + {"MaxPool3DGradGrad"}, + {"MaxPoolGrad"}, + {"MaxPoolGradGrad"}, + {"MaxPoolGradV2"}, + {"MaxPoolWithArgmax", 1, {0}}, + {"Maximum"}, + {"Merge", 1, {0}}, + {"MergeSummary"}, + {"Minimum"}, + {"MirrorPad"}, + {"MirrorPadGrad"}, + {"Mul"}, + {"MulNoNan"}, + {"Multinomial"}, + {"MutableDenseHashTable"}, + {"MutableDenseHashTableV2"}, + {"MutableHashTable"}, + {"MutableHashTableOfTensors"}, + {"MutableHashTableOfTensorsV2"}, + {"MutableHashTableV2"}, + {"NcclAllReduce"}, + {"NcclBroadcast"}, + {"NcclReduce"}, + {"Neg"}, + {"NextAfter"}, + {"NextIteration"}, + {"NonMaxSuppression"}, + {"NonMaxSuppressionV2"}, + {"NonMaxSuppressionWithOverlaps"}, + {"NotEqual"}, + {"OneHot"}, + {"OnesLike"}, + {"OptionalFromValue"}, + {"OptionalGetValue"}, + {"Pack"}, + {"Pad"}, + {"PadV2"}, + {"ParameterizedTruncatedNormal"}, + {"ParseTensor"}, + {"PlaceholderWithDefault"}, + {"Polygamma"}, + {"PopulationCount"}, + {"PreventGradient"}, + {"Print"}, + {"Prod"}, + {"QuantizeAndDequantize"}, + {"QuantizeAndDequantizeV2"}, + {"QuantizeAndDequantizeV3"}, + {"QueueClose"}, + {"QueueEnqueue"}, + {"QueueEnqueueMany"}, + {"QueueSize"}, + {"RFFT"}, + {"RFFT2D"}, + {"RaggedGather"}, + {"RaggedRange"}, + {"RaggedTensorToSparse"}, + {"RaggedTensorToTensor"}, + {"RaggedTensorToVariant"}, + {"RandomCrop"}, + {"RandomStandardNormal"}, + {"RandomUniform"}, + {"Range"}, + {"Rank"}, + {"ReadVariableOp"}, + {"ReaderNumRecordsProduced"}, + {"ReaderNumWorkUnitsCompleted"}, + {"ReaderRead"}, + {"ReaderReadUpTo"}, + {"ReaderReset"}, + {"ReaderRestoreState"}, + {"ReaderSerializeState"}, + {"Real"}, + {"RealDiv"}, + {"ReciprocalGrad"}, + {"ReduceJoin"}, + {"RefEnter"}, + {"RefExit"}, + {"RefIdentity"}, + {"RefMerge", 1, {0}}, + {"RefNextIteration"}, + {"RefSwitch"}, + {"RegexReplace"}, + {"Relu6Grad"}, + {"ReluGrad"}, + {"Reshape"}, + {"ResizeBicubic"}, + {"ResizeBilinear"}, + {"ResizeNearestNeighbor"}, + {"ResourceGather"}, + {"ResourceGatherNd"}, + {"Reverse"}, + {"ReverseSequence"}, + {"ReverseV2"}, + {"RightShift"}, + {"Rint"}, + {"Roll"}, + {"Round"}, + {"RsqrtGrad"}, + {"SampleDistortedBoundingBox"}, + {"SampleDistortedBoundingBoxV2"}, + {"ScalarSummary"}, + {"ScaleAndTranslate"}, + {"ScatterAdd"}, + {"ScatterDiv"}, + {"ScatterMul"}, + {"ScatterNd"}, + {"ScatterNdAdd"}, + {"ScatterNdNonAliasingAdd"}, + {"ScatterNdSub"}, + {"ScatterNdUpdate"}, + {"ScatterSub"}, + {"SdcaFprint"}, + {"SdcaShrinkL1"}, + {"SegmentMean"}, + {"SegmentSum"}, + {"Select"}, + {"SeluGrad"}, + {"SerializeTensor"}, + {"SetSize"}, + {"Shape"}, + {"SigmoidGrad"}, + {"Sign"}, + {"Sin"}, + {"Sinh"}, + {"Size"}, + {"SoftmaxCrossEntropyWithLogits", 1, {0}}, + {"Softplus"}, + {"SoftplusGrad"}, + {"Softsign"}, + {"SpaceToBatch"}, + {"SpaceToBatchND"}, + {"SpaceToDepth"}, + {"SparseAdd", 2, {1, 2}}, + {"SparseAddGrad"}, + {"SparseConcat"}, + {"SparseDenseCwiseAdd"}, + {"SparseDenseCwiseDiv"}, + {"SparseDenseCwiseMul"}, + {"SparseFillEmptyRows", 3, {0, 1, 2}}, + {"SparseMatMul"}, + {"SparseMatrixAdd"}, + {"SparseMatrixMatMul"}, + {"SparseMatrixMul"}, + {"SparseMatrixNNZ"}, + {"SparseMatrixSparseMatMul"}, + {"SparseMatrixTranspose"}, + {"SparseMatrixZeros"}, + {"SparseReduceSum"}, + {"SparseReorder"}, + {"SparseSegmentMean"}, + {"SparseSegmentMeanWithNumSegments"}, + {"SparseSegmentSqrtN"}, + {"SparseSegmentSqrtNWithNumSegments"}, + {"SparseSegmentSum"}, + {"SparseSegmentSumWithNumSegments"}, + {"SparseSlice", 2, {1, 2}}, + {"SparseSoftmaxCrossEntropyWithLogits", 1, {0}}, + {"SparseSparseMaximum"}, + {"SparseSparseMinimum"}, + {"SparseTensorDenseAdd"}, + {"SparseTensorDenseMatMul"}, + {"SparseToDense"}, + {"SparseToSparseSetOperation"}, + {"Spence"}, + {"Split"}, + {"SplitV"}, + {"Square"}, + {"SquaredDifference"}, + {"Squeeze"}, + {"Stack"}, + {"StackClose"}, + {"StackPop"}, + {"StackPush"}, + {"StatelessMultinomial"}, + {"StatelessRandomBinomial"}, + {"StatelessRandomNormal"}, + {"StatelessRandomPoisson"}, + {"StatelessRandomUniform"}, + {"StatelessRandomUniformFullInt"}, + {"StatelessRandomUniformInt"}, + {"StatelessTruncatedNormal"}, + {"StopGradient"}, + {"StridedSlice"}, + {"StridedSliceGrad"}, + {"StringJoin"}, + {"StringSplit"}, + {"StringToHashBucket"}, + {"StringToHashBucketFast"}, + {"StringToHashBucketStrong"}, + {"StringToNumber"}, + {"Sub"}, + {"Sum"}, + {"Switch"}, + {"TFRecordReader"}, + {"TPUEmbeddingActivations"}, + {"TPUReplicatedInput"}, + {"Tan"}, + {"TanhGrad"}, + {"TensorArray"}, + {"TensorArrayClose"}, + {"TensorArrayCloseV2"}, + {"TensorArrayCloseV3"}, + {"TensorArrayConcat", 1, {0}}, + {"TensorArrayConcatV2", 1, {0}}, + {"TensorArrayConcatV3", 1, {0}}, + {"TensorArrayGather"}, + {"TensorArrayGatherV2"}, + {"TensorArrayGatherV3"}, + {"TensorArrayGrad"}, + {"TensorArrayGradV2"}, + {"TensorArrayGradV3"}, + {"TensorArrayGradWithShape"}, + {"TensorArrayRead"}, + {"TensorArrayReadV2"}, + {"TensorArrayReadV3"}, + {"TensorArrayScatter"}, + {"TensorArrayScatterV2"}, + {"TensorArrayScatterV3"}, + {"TensorArraySize"}, + {"TensorArraySizeV2"}, + {"TensorArraySizeV3"}, + {"TensorArraySplit"}, + {"TensorArraySplitV2"}, + {"TensorArraySplitV3"}, + {"TensorArrayV2"}, + {"TensorArrayV3"}, + {"TensorArrayWrite"}, + {"TensorArrayWriteV2"}, + {"TensorArrayWriteV3"}, + {"TensorListConcat", 1, {0}}, + {"TensorListConcatLists"}, + {"TensorListConcatV2", 1, {0}}, + {"TensorListElementShape"}, + {"TensorListGather"}, + {"TensorListGetItem"}, + {"TensorListLength"}, + {"TensorListPopBack", 1, {1}}, + {"TensorListPushBack"}, + {"TensorListPushBackBatch"}, + {"TensorListResize"}, + {"TensorListScatter"}, + {"TensorListScatterIntoExistingList"}, + {"TensorListScatterV2"}, + {"TensorListSetItem"}, + {"TensorListSplit"}, + {"TensorListStack"}, + {"TensorScatterAdd"}, + {"TensorScatterSub"}, + {"TensorScatterUpdate"}, + {"TensorSummary"}, + {"TensorSummaryV2"}, + {"TextLineReader"}, + {"Tile"}, + {"Timestamp"}, + {"TopK", 1, {0}}, + {"TopKV2", 1, {0}}, + {"Transpose"}, + {"TridiagonalMatMul"}, + {"TruncateDiv"}, + {"TruncatedNormal"}, + {"Unpack"}, + {"UnsortedSegmentSum"}, + {"VarIsInitializedOp"}, + {"VariableShape"}, + {"WholeFileReader"}, + {"Xdivy"}, + {"XlaClusterOutput"}, + {"XlaEinsum"}, + {"XlaSharding"}, + {"Xlog1py"}, + {"Xlogy"}, + {"ZerosLike"}, + {"Zeta"}, + {"VarHandleOp"}, + }}; + static const auto &m = *OpGradientInfoInit(a); - auto it = m->find(op_name); - - if (it == m->end()) return false; - - *output = &it->second; - return true; + auto it = m.find(op_name); + if (it != m.end()) { + return it->second; + } + return absl::nullopt; } diff --git a/tensorflow/python/eager/pywrap_gradient_exclusions.h b/tensorflow/python/eager/pywrap_gradient_exclusions.h index 7e8908f79a2..4fac635ccd5 100644 --- a/tensorflow/python/eager/pywrap_gradient_exclusions.h +++ b/tensorflow/python/eager/pywrap_gradient_exclusions.h @@ -15,15 +15,24 @@ limitations under the License. #ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_ #define TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_ +#include "absl/types/optional.h" #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" -bool OpGradientDoesntRequireInputIndices( - const tensorflow::string& op_name, - std::pair<bool, tensorflow::gtl::FlatSet<int>>** output); +// Lookup whether the Op with the given op_name has unused input indices. +// Returns absl::nullopt if all inputs are used, set of unused indices +// otherwise. Empty set indicates that all indices are unused. The latter is +// necessary because sometimes it may not be possible to enumerate all indices +// just using OpDef e.g. when there are `list(T)` or `N * T` type inputs. +absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices( + const tensorflow::string& op_name); -bool OpGradientDoesntRequireOutputIndices( - const tensorflow::string& op_name, - std::pair<bool, tensorflow::gtl::FlatSet<int>>** output); +// Lookup whether the Op with the given op_name has unused output indices. +// Returns absl::nullopt if all outputs are used, set of unused indices +// otherwise. Empty set indicates that all indices are unused. The latter is +// necessary because sometimes it may not be possible to enumerate all indices +// just using OpDef e.g. when there are `list(T)` or `N * T` type outputs. +absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices( + const tensorflow::string& op_name); #endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_ diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index 6daacae9d4f..39ea862ba5e 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -2944,15 +2944,15 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, PyObject* op_outputs; bool op_outputs_tuple_created = false; - std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required; - if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) { - if (outputs_not_required->first) { + if (const auto unused_output_indices = + OpGradientUnusedOutputIndices(c_op_name)) { + if (unused_output_indices->empty()) { op_outputs = Py_None; } else { op_outputs_tuple_created = true; - op_outputs = CopySequenceSettingIndicesToNull( - results, outputs_not_required->second); + op_outputs = + CopySequenceSettingIndicesToNull(results, *unused_output_indices); } } else { op_outputs = results; @@ -2960,15 +2960,15 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, PyObject* op_inputs; bool op_inputs_tuple_created = false; - std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required; - if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) { - if (inputs_not_required->first) { + if (const auto unused_input_indices = + OpGradientUnusedInputIndices(c_op_name)) { + if (unused_input_indices->empty()) { op_inputs = Py_None; } else { op_inputs_tuple_created = true; op_inputs = - CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second); + CopySequenceSettingIndicesToNull(inputs, *unused_input_indices); } } else { op_inputs = inputs;