Change the header generated for selective registration to use the isequals
defined in the header, instead of strcmp. strcmp is not a constexpr on all platforms. PiperOrigin-RevId: 156914421
This commit is contained in:
parent
630005aa52
commit
f333979e0d
@ -156,13 +156,6 @@ class PrintOpFilegroupTest(test.TestCase):
|
||||
expected = '''// This file was autogenerated by %s
|
||||
#ifndef OPS_TO_REGISTER
|
||||
#define OPS_TO_REGISTER
|
||||
constexpr inline bool ShouldRegisterOp(const char op[]) {
|
||||
return false
|
||||
|| (strcmp(op, "BiasAdd") == 0)
|
||||
;
|
||||
}
|
||||
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
|
||||
|
||||
|
||||
namespace {
|
||||
constexpr const char* skip(const char* x) {
|
||||
@ -194,6 +187,13 @@ constexpr inline bool ShouldRegisterOp(const char op[]) {
|
||||
};
|
||||
#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))
|
||||
|
||||
constexpr inline bool ShouldRegisterOp(const char op[]) {
|
||||
return false
|
||||
|| isequal(op, "BiasAdd")
|
||||
;
|
||||
}
|
||||
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
|
||||
|
||||
#define SHOULD_REGISTER_OP_GRADIENT false
|
||||
#endif''' % self.script_name
|
||||
|
||||
|
@ -100,15 +100,6 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
|
||||
append('#define SHOULD_REGISTER_OP_KERNEL(clz) true')
|
||||
append('#define SHOULD_REGISTER_OP_GRADIENT true')
|
||||
else:
|
||||
append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
|
||||
append(' return false')
|
||||
for op in sorted(ops):
|
||||
append(' || (strcmp(op, "%s") == 0)' % op)
|
||||
append(' ;')
|
||||
append('}')
|
||||
append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
|
||||
append('')
|
||||
|
||||
line = '''
|
||||
namespace {
|
||||
constexpr const char* skip(const char* x) {
|
||||
@ -147,6 +138,15 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
|
||||
'kNecessaryOpKernelClasses))')
|
||||
append('')
|
||||
|
||||
append('constexpr inline bool ShouldRegisterOp(const char op[]) {')
|
||||
append(' return false')
|
||||
for op in sorted(ops):
|
||||
append(' || isequal(op, "%s")' % op)
|
||||
append(' ;')
|
||||
append('}')
|
||||
append('#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)')
|
||||
append('')
|
||||
|
||||
append('#define SHOULD_REGISTER_OP_GRADIENT ' + (
|
||||
'true' if 'SymbolicGradient' in ops else 'false'))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user