In the current PluginAsset api, the PluginAsset provides a serialize_to_directory method in which it directly writes contents to disk. This means that we as framework maintainers don't have the flexibility to change the serialization strategy in different contexts, e.g. providing good ways to serialize in contexts where we are writing to a db rather than to disk. Also, it presents a trivial landmine where users may use standard python file APIs rather than gfile and thus provide implementations that work externally but break within g3. After the change, the PluginAsset instead provides an 'assets' method, which provides asset names and asset contents. How this gets written out is now an implementation detail handled by the tf.summary.FileWriter. We haven't yet exposed the PluginAsset class as part of the TensorFlow API (it's hidden) so this isn't an API break. Change: 149985564
82 lines
2.8 KiB
Python
82 lines
2.8 KiB
Python
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.platform import googletest
|
|
from tensorflow.python.summary import plugin_asset
|
|
|
|
|
|
class _UnnamedPluginAsset(plugin_asset.PluginAsset):
|
|
"""An example asset with a dummy serialize method provided, but no name."""
|
|
|
|
def assets(self):
|
|
return {}
|
|
|
|
|
|
class _ExamplePluginAsset(_UnnamedPluginAsset):
|
|
"""Simple example asset."""
|
|
plugin_name = "_ExamplePluginAsset"
|
|
|
|
|
|
class _OtherExampleAsset(_UnnamedPluginAsset):
|
|
"""Simple example asset."""
|
|
plugin_name = "_OtherExampleAsset"
|
|
|
|
|
|
class _ExamplePluginThatWillCauseCollision(_UnnamedPluginAsset):
|
|
plugin_name = "_ExamplePluginAsset"
|
|
|
|
|
|
class PluginAssetTest(test_util.TensorFlowTestCase):
|
|
|
|
def testGetPluginAsset(self):
|
|
epa = plugin_asset.get_plugin_asset(_ExamplePluginAsset)
|
|
self.assertIsInstance(epa, _ExamplePluginAsset)
|
|
epa2 = plugin_asset.get_plugin_asset(_ExamplePluginAsset)
|
|
self.assertIs(epa, epa2)
|
|
opa = plugin_asset.get_plugin_asset(_OtherExampleAsset)
|
|
self.assertIsNot(epa, opa)
|
|
|
|
def testUnnamedPluginFails(self):
|
|
with self.assertRaises(ValueError):
|
|
plugin_asset.get_plugin_asset(_UnnamedPluginAsset)
|
|
|
|
def testPluginCollisionDetected(self):
|
|
plugin_asset.get_plugin_asset(_ExamplePluginAsset)
|
|
with self.assertRaises(ValueError):
|
|
plugin_asset.get_plugin_asset(_ExamplePluginThatWillCauseCollision)
|
|
|
|
def testGetAllPluginAssets(self):
|
|
epa = plugin_asset.get_plugin_asset(_ExamplePluginAsset)
|
|
opa = plugin_asset.get_plugin_asset(_OtherExampleAsset)
|
|
self.assertItemsEqual(plugin_asset.get_all_plugin_assets(), [epa, opa])
|
|
|
|
def testRespectsGraphArgument(self):
|
|
g1 = ops.Graph()
|
|
g2 = ops.Graph()
|
|
e1 = plugin_asset.get_plugin_asset(_ExamplePluginAsset, g1)
|
|
e2 = plugin_asset.get_plugin_asset(_ExamplePluginAsset, g2)
|
|
|
|
self.assertEqual(e1, plugin_asset.get_all_plugin_assets(g1)[0])
|
|
self.assertEqual(e2, plugin_asset.get_all_plugin_assets(g2)[0])
|
|
|
|
if __name__ == "__main__":
|
|
googletest.main()
|