Change the header generated for selective registration to use the isequals

defined in the header, instead of strcmp. strcmp is not a constexpr on all
platforms.

PiperOrigin-RevId: 156914421
This commit is contained in:
A. Unique TensorFlower 2017-05-23 14:48:39 -07:00 committed by TensorFlower Gardener
parent 630005aa52
commit f333979e0d
2 changed files with 16 additions and 16 deletions

View File

@ -156,13 +156,6 @@ class PrintOpFilegroupTest(test.TestCase):
expected = '''// This file was autogenerated by %s
#ifndef OPS_TO_REGISTER
#define OPS_TO_REGISTER
constexpr inline bool ShouldRegisterOp(const char op[]) {
return false
|| (strcmp(op, "BiasAdd") == 0)
;
}
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
namespace {
constexpr const char* skip(const char* x) {
@ -194,6 +187,13 @@ constexpr inline bool ShouldRegisterOp(const char op[]) {
};
#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))
constexpr inline bool ShouldRegisterOp(const char op[]) {
return false
|| isequal(op, "BiasAdd")
;
}
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
#define SHOULD_REGISTER_OP_GRADIENT false
#endif''' % self.script_name

View File

@ -100,15 +100,6 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
append('#define SHOULD_REGISTER_OP_KERNEL(clz) true')
append('#define SHOULD_REGISTER_OP_GRADIENT true')
else:
append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
append(' return false')
for op in sorted(ops):
append(' || (strcmp(op, "%s") == 0)' % op)
append(' ;')
append('}')
append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
append('')
line = '''
namespace {
constexpr const char* skip(const char* x) {
@ -147,6 +138,15 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
'kNecessaryOpKernelClasses))')
append('')
append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
append(' return false')
for op in sorted(ops):
append(' || isequal(op, "%s")' % op)
append(' ;')
append('}')
append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
append('')
append('#define SHOULD_REGISTER_OP_GRADIENT ' + (
'true' if 'SymbolicGradient' in ops else 'false'))