Add a method to list op names in an ApiDefMap.
PiperOrigin-RevId: 197234301
This commit is contained in:
parent
92cdc99c98
commit
81168f21b7
@ -134,6 +134,9 @@ class ApiDefMap(object):
|
||||
return self._op_per_name[op_name]
|
||||
raise ValueError("No entry found for " + op_name + ".")
|
||||
|
||||
def op_names(self):
|
||||
return self._op_per_name.keys()
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def tf_buffer(data=None):
|
||||
|
@ -25,6 +25,10 @@ from tensorflow.python.platform import googletest
|
||||
|
||||
class ApiDefMapTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testApiDefMapOpNames(self):
|
||||
api_def_map = c_api_util.ApiDefMap()
|
||||
self.assertIn("Add", api_def_map.op_names())
|
||||
|
||||
def testApiDefMapGet(self):
|
||||
api_def_map = c_api_util.ApiDefMap()
|
||||
op_def = api_def_map.get_op_def("Add")
|
||||
|
Loading…
x
Reference in New Issue
Block a user