Add a method to list op names in an ApiDefMap.

PiperOrigin-RevId: 197234301
This commit is contained in:
A. Unique TensorFlower 2018-05-18 19:25:41 -07:00 committed by TensorFlower Gardener
parent 92cdc99c98
commit 81168f21b7
2 changed files with 7 additions and 0 deletions

View File

@ -134,6 +134,9 @@ class ApiDefMap(object):
return self._op_per_name[op_name]
raise ValueError("No entry found for " + op_name + ".")
def op_names(self):
return self._op_per_name.keys()
@tf_contextlib.contextmanager
def tf_buffer(data=None):

View File

@ -25,6 +25,10 @@ from tensorflow.python.platform import googletest
class ApiDefMapTest(test_util.TensorFlowTestCase):
def testApiDefMapOpNames(self):
api_def_map = c_api_util.ApiDefMap()
self.assertIn("Add", api_def_map.op_names())
def testApiDefMapGet(self):
api_def_map = c_api_util.ApiDefMap()
op_def = api_def_map.get_op_def("Add")