diff --git a/tensorflow/python/tools/print_selective_registration_header_test.py b/tensorflow/python/tools/print_selective_registration_header_test.py index fe20df59246..36978b0860a 100644 --- a/tensorflow/python/tools/print_selective_registration_header_test.py +++ b/tensorflow/python/tools/print_selective_registration_header_test.py @@ -156,13 +156,6 @@ class PrintOpFilegroupTest(test.TestCase): expected = '''// This file was autogenerated by %s #ifndef OPS_TO_REGISTER #define OPS_TO_REGISTER -constexpr inline bool ShouldRegisterOp(const char op[]) { - return false - || (strcmp(op, "BiasAdd") == 0) - ; -} -#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op) - namespace { constexpr const char* skip(const char* x) { @@ -194,6 +187,13 @@ constexpr inline bool ShouldRegisterOp(const char op[]) { }; #define SHOULD_REGISTER_OP_KERNEL(clz) (find_in::f(clz, kNecessaryOpKernelClasses)) +constexpr inline bool ShouldRegisterOp(const char op[]) { + return false + || isequal(op, "BiasAdd") + ; +} +#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op) + #define SHOULD_REGISTER_OP_GRADIENT false #endif''' % self.script_name diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py index 7be61ca379e..7f7470994dd 100644 --- a/tensorflow/python/tools/selective_registration_header_lib.py +++ b/tensorflow/python/tools/selective_registration_header_lib.py @@ -100,15 +100,6 @@ def get_header_from_ops_and_kernels(ops_and_kernels, append('#define SHOULD_REGISTER_OP_KERNEL(clz) true') append('#define SHOULD_REGISTER_OP_GRADIENT true') else: - append('constexpr inline bool ShouldRegisterOp(const char op[]) {') - append(' return false') - for op in sorted(ops): - append(' || (strcmp(op, "%s") == 0)' % op) - append(' ;') - append('}') - append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)') - append('') - line = ''' namespace { constexpr const char* skip(const char* x) { @@ -147,6 +138,15 @@ def get_header_from_ops_and_kernels(ops_and_kernels, 'kNecessaryOpKernelClasses))') append('') + append('constexpr inline bool ShouldRegisterOp(const char op[]) {') + append(' return false') + for op in sorted(ops): + append(' || isequal(op, "%s")' % op) + append(' ;') + append('}') + append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)') + append('') + append('#define SHOULD_REGISTER_OP_GRADIENT ' + ( 'true' if 'SymbolicGradient' in ops else 'false'))