Simplify gradient exclusions data to speed up compilation w/ clang on windows.
PiperOrigin-RevId: 294998085
Change-Id: Ie56b8f2cf4ed1e5fd8e2a641947b2d69f316e86a
parent c8f74b6215
commit 22f801369f
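Note on the shape of the change (a sketch only: the op names, bool/index values, and the std::map/std::set stand-ins for tensorflow::gtl::FlatMap/FlatSet are illustrative, not the real generated table). The old generated file initialized one huge map in a single brace-initialized expression, where every entry materializes a std::pair<bool, FlatSet<int>>; per the commit title, clang on Windows compiled that slowly. The new file instead emits a flat array of small OpIndexInfo aggregates and builds the map once at runtime:

    #include <array>
    #include <map>      // stand-in for tensorflow::gtl::FlatMap
    #include <set>      // stand-in for tensorflow::gtl::FlatSet
    #include <string>
    #include <utility>

    // Old shape: one giant brace-initialized map expression; each entry is a
    // std::pair<bool, std::set<int>> temporary, expensive to compile at scale.
    static const auto *old_style =
        new std::map<std::string, std::pair<bool, std::set<int>>>({
            {"Acosh", {true, {}}},
            {"FusedBatchNorm", {false, {2}}},
        });

    // New shape: a flat array of small aggregates the compiler can lay out
    // trivially; the FlatMap is built once at runtime (OpGradientInfoInit).
    struct OpIndexInfo {
      const char *op_name;
      int num_indices;
      std::array<int, 4> unused_indices;
    };
    static const std::array<OpIndexInfo, 2> kTable = {{
        {"Acosh"},                   // num_indices 0: empty set
        {"FusedBatchNorm", 1, {2}},  // exactly one unused index: 2
    }};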
@@ -54,6 +54,7 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:variant",
     ],
 )
@@ -63,10 +63,33 @@ limitations under the License.
 _INCLUDES = """
 #include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 
 using tensorflow::string;
+
+namespace {
+// Keep static data in a format that's easy to init statically.
+struct OpIndexInfo {
+  const char *op_name;
+  int num_indices;
+  std::array<int, 4> unused_indices;
+};
+
+// Helper function to initialize FlatMap<string,FlatSet> from OpIndexInfo.
+template <typename T>
+auto OpGradientInfoInit(const T &a) {
+  auto *m = new tensorflow::gtl::FlatMap<string, tensorflow::gtl::FlatSet<int>>;
+  for (const auto &item : a) {
+    m->emplace(string(item.op_name),
+               tensorflow::gtl::FlatSet<int>(
+                   item.unused_indices.begin(),
+                   item.unused_indices.begin() + item.num_indices));
+  }
+  return m;
+}
+}  // namespace
 """
 
 _EXCLUDED_OPS = [
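How an OpIndexInfo row decodes (hypothetical values; the construction below is lifted from OpGradientInfoInit in the hunk above). Only the first num_indices slots of the fixed-size unused_indices array are read, so zero-filled tail slots never reach the set, and a row with num_indices == 0 yields an empty set:

    // Hypothetical row: unused_indices actually holds {1, 3, 0, 0}.
    OpIndexInfo info{"SomeOp", 2, {1, 3}};
    tensorflow::gtl::FlatSet<int> s(
        info.unused_indices.begin(),
        info.unused_indices.begin() + info.num_indices);
    // s == {1, 3}; the two zero-filled tail slots are ignored.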
@@ -281,7 +304,6 @@ def get_entries(attr_name):
   """
   assert attr_name in ["inputs", "outputs"]
   entries = {}
-  spaces = "        "
   for op_type in ops._gradient_registry.list():  # pylint: disable=protected-access
     if op_type in _EXCLUDED_OPS:
       continue
@@ -291,72 +313,57 @@ def get_entries(attr_name):
     if gradient_fn is None:
       # NotDifferentiable
       if num_values != -1:
-        entries[op_type] = spaces + "{\"%s\", {true, {}}}," % op_type
+        entries[op_type] = "{\"%s\"}," % op_type
       continue
     used_tensors = _live_tensors(gradient_fn, attr_name=attr_name)
     if used_tensors is _ALL:
       continue
     elif not used_tensors:
-      entries[op_type] = spaces + "{\"%s\", {true, {}}}," % op_type
+      entries[op_type] = "{\"%s\"}," % op_type
     else:
       all_tensors = set(range(num_values))
       unused_tensors = all_tensors - used_tensors
       if unused_tensors:
-        entries[op_type] = spaces + "{\"%s\", {false, {%s}}}," % (
-            op_type, ", ".join(str(i) for i in sorted(list(unused_tensors))))
+        unused_tensor_list = sorted(list(unused_tensors))
+        entries[op_type] = "{\"%s\", %d, {%s}}," % (
+            op_type, len(unused_tensor_list), ", ".join(
+                str(i) for i in unused_tensor_list))
   return entries
 
 
+def get_function(name, entries):
+  """Generates lookup function with given name and lookup table entries."""
+  contents = """
+absl::optional<tensorflow::gtl::FlatSet<int>> {name}(
+    const tensorflow::string &op_name) {{
+  static std::array<OpIndexInfo, {count}> a = {{{{
+""".format(
+      name=name, count=len(entries) + 1)
+  contents += "      "
+  contents += "\n      ".join(entries[op_type] for op_type in sorted(entries))
+  contents += "\n      {\"VarHandleOp\"},"
+  contents += """
+  }};
+  static const auto &m = *OpGradientInfoInit(a);
+
+  auto it = m.find(op_name);
+  if (it != m.end()) {
+    return it->second;
+  }
+  return absl::nullopt;
+}
+"""
+  return contents
+
+
 def get_contents():
   """Returns contents for the generated file."""
   contents = ""
   contents += _GENERATED_FILE_HEADER + _INCLUDES
-  contents += """
-bool OpGradientDoesntRequireInputIndices(
-    const string& op_name,
-    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
-  static tensorflow::gtl::FlatMap<
-      string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
-      new tensorflow::gtl::FlatMap<
-          string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
-"""
-
-  entries = get_entries("inputs")
-  contents += "\n".join(entries[op_type] for op_type in sorted(entries))
-  contents += "\n        {\"VarHandleOp\", {true, {}}},\n"
-  contents += """  });
-
-  auto it = m->find(op_name);
-
-  if (it == m->end()) return false;
-
-  *output = &it->second;
-  return true;
-}
-"""
-  contents += """
-bool OpGradientDoesntRequireOutputIndices(
-    const string& op_name,
-    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
-  static tensorflow::gtl::FlatMap<
-      string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
-      new tensorflow::gtl::FlatMap<
-          string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
-"""
-
-  entries = get_entries("outputs")
-  contents += "\n".join(entries[op_type] for op_type in sorted(entries))
-  contents += "\n        {\"VarHandleOp\", {true, {}}},\n"
-  contents += """  });
-
-  auto it = m->find(op_name);
-
-  if (it == m->end()) return false;
-
-  *output = &it->second;
-  return true;
-}
-"""
+  contents += get_function("OpGradientUnusedInputIndices",
+                           get_entries("inputs"))
+  contents += get_function("OpGradientUnusedOutputIndices",
+                           get_entries("outputs"))
   return contents
 
 
File diff suppressed because it is too large
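The suppressed diff is presumably the regenerated table file itself. For illustration, instantiating the get_function template above with a hypothetical two-entry table (plus the unconditional {"VarHandleOp"} row, hence count = 3) would emit a function of this shape:

    absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
        const tensorflow::string &op_name) {
      static std::array<OpIndexInfo, 3> a = {{
          {"Fill"},                    // hypothetical entry: all indices unused
          {"FusedBatchNorm", 1, {2}},  // hypothetical entry: index 2 unused
          {"VarHandleOp"},
      }};
      static const auto &m = *OpGradientInfoInit(a);

      auto it = m.find(op_name);
      if (it != m.end()) {
        return it->second;
      }
      return absl::nullopt;
    }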
@@ -15,15 +15,24 @@ limitations under the License.
 #ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_
 #define TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 
-bool OpGradientDoesntRequireInputIndices(
-    const tensorflow::string& op_name,
-    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output);
+// Lookup whether the Op with the given op_name has unused input indices.
+// Returns absl::nullopt if all inputs are used, set of unused indices
+// otherwise. Empty set indicates that all indices are unused. The latter is
+// necessary because sometimes it may not be possible to enumerate all indices
+// just using OpDef e.g. when there are `list(T)` or `N * T` type inputs.
+absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
+    const tensorflow::string& op_name);
 
-bool OpGradientDoesntRequireOutputIndices(
-    const tensorflow::string& op_name,
-    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output);
+// Lookup whether the Op with the given op_name has unused output indices.
+// Returns absl::nullopt if all outputs are used, set of unused indices
+// otherwise. Empty set indicates that all indices are unused. The latter is
+// necessary because sometimes it may not be possible to enumerate all indices
+// just using OpDef e.g. when there are `list(T)` or `N * T` type outputs.
+absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
+    const tensorflow::string& op_name);
 
 #endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_
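The tri-state contract in practice (a sketch; `unused` and `op_name` are placeholder names, and the pattern mirrors the RecordGradient changes below):

    if (const auto unused = OpGradientUnusedInputIndices(op_name)) {
      if (unused->empty()) {
        // Empty set: every input is unused; the tape can drop them all.
      } else {
        // Only the listed input indices are unused.
      }
    } else {
      // absl::nullopt: all inputs are potentially used; keep them all.
    }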
@@ -2944,15 +2944,15 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
 
   PyObject* op_outputs;
   bool op_outputs_tuple_created = false;
-  std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
 
-  if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
-    if (outputs_not_required->first) {
+  if (const auto unused_output_indices =
+          OpGradientUnusedOutputIndices(c_op_name)) {
+    if (unused_output_indices->empty()) {
       op_outputs = Py_None;
     } else {
       op_outputs_tuple_created = true;
-      op_outputs = CopySequenceSettingIndicesToNull(
-          results, outputs_not_required->second);
+      op_outputs =
+          CopySequenceSettingIndicesToNull(results, *unused_output_indices);
     }
   } else {
     op_outputs = results;
@@ -2960,15 +2960,15 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
 
   PyObject* op_inputs;
   bool op_inputs_tuple_created = false;
-  std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;
 
-  if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
-    if (inputs_not_required->first) {
+  if (const auto unused_input_indices =
+          OpGradientUnusedInputIndices(c_op_name)) {
+    if (unused_input_indices->empty()) {
       op_inputs = Py_None;
     } else {
       op_inputs_tuple_created = true;
       op_inputs =
-          CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
+          CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
     }
   } else {
     op_inputs = inputs;