Print an "autogenerated" message into selective registration headers.
Change: 151759343
This commit is contained in:
parent
6737941fb4
commit
a311a55c46
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
|
|
||||||
@ -71,6 +72,9 @@ GRAPH_DEF_TXT_2 = """
|
|||||||
|
|
||||||
class PrintOpFilegroupTest(test.TestCase):
|
class PrintOpFilegroupTest(test.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
_, self.script_name = os.path.split(sys.argv[0])
|
||||||
|
|
||||||
def WriteGraphFiles(self, graphs):
|
def WriteGraphFiles(self, graphs):
|
||||||
fnames = []
|
fnames = []
|
||||||
for i, graph in enumerate(graphs):
|
for i, graph in enumerate(graphs):
|
||||||
@ -130,6 +134,7 @@ class PrintOpFilegroupTest(test.TestCase):
|
|||||||
ops_and_kernels, include_all_ops_and_kernels=True)
|
ops_and_kernels, include_all_ops_and_kernels=True)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[
|
[
|
||||||
|
'// This file was autogenerated by %s' % self.script_name,
|
||||||
'#ifndef OPS_TO_REGISTER', #
|
'#ifndef OPS_TO_REGISTER', #
|
||||||
'#define OPS_TO_REGISTER', #
|
'#define OPS_TO_REGISTER', #
|
||||||
'#define SHOULD_REGISTER_OP(op) true', #
|
'#define SHOULD_REGISTER_OP(op) true', #
|
||||||
@ -148,7 +153,8 @@ class PrintOpFilegroupTest(test.TestCase):
|
|||||||
default_ops = ''
|
default_ops = ''
|
||||||
graphs = [text_format.Parse(GRAPH_DEF_TXT_2, graph_pb2.GraphDef())]
|
graphs = [text_format.Parse(GRAPH_DEF_TXT_2, graph_pb2.GraphDef())]
|
||||||
|
|
||||||
expected = '''#ifndef OPS_TO_REGISTER
|
expected = '''// This file was autogenerated by %s
|
||||||
|
#ifndef OPS_TO_REGISTER
|
||||||
#define OPS_TO_REGISTER
|
#define OPS_TO_REGISTER
|
||||||
constexpr inline bool ShouldRegisterOp(const char op[]) {
|
constexpr inline bool ShouldRegisterOp(const char op[]) {
|
||||||
return false
|
return false
|
||||||
@ -189,7 +195,7 @@ constexpr inline bool ShouldRegisterOp(const char op[]) {
|
|||||||
#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))
|
#define SHOULD_REGISTER_OP_KERNEL(clz) (find_in<sizeof(kNecessaryOpKernelClasses) / sizeof(*kNecessaryOpKernelClasses)>::f(clz, kNecessaryOpKernelClasses))
|
||||||
|
|
||||||
#define SHOULD_REGISTER_OP_GRADIENT false
|
#define SHOULD_REGISTER_OP_GRADIENT false
|
||||||
#endif'''
|
#endif''' % self.script_name
|
||||||
|
|
||||||
header = selective_registration_header_lib.get_header(
|
header = selective_registration_header_lib.get_header(
|
||||||
self.WriteGraphFiles(graphs), 'rawproto', default_ops)
|
self.WriteGraphFiles(graphs), 'rawproto', default_ops)
|
||||||
|
@ -22,6 +22,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from google.protobuf import text_format
|
from google.protobuf import text_format
|
||||||
@ -89,6 +90,8 @@ def get_header_from_ops_and_kernels(ops_and_kernels,
|
|||||||
def append(s):
|
def append(s):
|
||||||
result_list.append(s)
|
result_list.append(s)
|
||||||
|
|
||||||
|
_, script_name = os.path.split(sys.argv[0])
|
||||||
|
append('// This file was autogenerated by %s' % script_name)
|
||||||
append('#ifndef OPS_TO_REGISTER')
|
append('#ifndef OPS_TO_REGISTER')
|
||||||
append('#define OPS_TO_REGISTER')
|
append('#define OPS_TO_REGISTER')
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user