Print an "autogenerated" message into selective registration headers.
Change: 151759343
This commit is contained in:
parent
6737941fb4
commit
a311a55c46
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from google.protobuf import text_format
|
||||
|
||||
@ -71,6 +72,9 @@ GRAPH_DEF_TXT_2 = """
|
||||
|
||||
class PrintOpFilegroupTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
_, self.script_name = os.path.split(sys.argv[0])
|
||||
|
||||
def WriteGraphFiles(self, graphs):
|
||||
fnames = []
|
||||
for i, graph in enumerate(graphs):
|
||||
@ -130,6 +134,7 @@ class PrintOpFilegroupTest(test.TestCase):
|
||||
ops_and_kernels, include_all_ops_and_kernels=True)
|
||||
self.assertListEqual(
|
||||
[
|
||||
'// This file was autogenerated by %s' % self.script_name,
|
||||
'#ifndef OPS_TO_REGISTER', #
|
||||
'#define OPS_TO_REGISTER', #
|
||||
'#define SHOULD_REGISTER_OP(op) true', #
|
||||
@ -148,7 +153,8 @@ class PrintOpFilegroupTest(test.TestCase):
|
||||
default_ops = ''
|
||||
graphs = [text_format.Parse(GRAPH_DEF_TXT_2, graph_pb2.GraphDef())]
|
||||
|
||||
expected = '''#ifndef OPS_TO_REGISTER
|
||||
expected = '''// This file was autogenerated by %s
|
||||
#ifndef OPS_TO_REGISTER
|
||||
#define OPS_TO_REGISTER
|
||||
constexpr inline bool ShouldRegisterOp(const char op[]) {
|
||||
return false
|
||||
@ -189,7 +195,7 @@ constexpr inline bool ShouldRegisterOp(const char op[]) {
|
||||
#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))
|
||||
|
||||
#define SHOULD_REGISTER_OP_GRADIENT false
|
||||
#endif'''
|
||||
#endif''' % self.script_name
|
||||
|
||||
header = selective_registration_header_lib.get_header(
|
||||
self.WriteGraphFiles(graphs), 'rawproto', default_ops)
|
||||
|
@ -22,6 +22,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from google.protobuf import text_format
|
||||
@ -89,6 +90,8 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
|
||||
def append(s):
|
||||
result_list.append(s)
|
||||
|
||||
_, script_name = os.path.split(sys.argv[0])
|
||||
append('// This file was autogenerated by %s' % script_name)
|
||||
append('#ifndef OPS_TO_REGISTER')
|
||||
append('#define OPS_TO_REGISTER')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user