Add make_ndarray, tensor_proto, and MetaGraphDef to tf api.

Since TensorProtos are part of the TensorFlow API, it makes sense
to also include the methods that generate and parse them.

Similarly, we write out MetaGraphDef protos in the summary writer,
so we should provide the proto as well.

This is part of an ongoing effort to have TensorBoard only consume
TensorFlow methods through the public api.

PiperOrigin-RevId: 157657564
This commit is contained in:
Dandelion Man? 2017-05-31 17:22:42 -07:00 committed by TensorFlower Gardener
parent 458f94c128
commit 0462416f64
10 changed files with 414 additions and 19 deletions

View File

@ -54,6 +54,7 @@ from tensorflow.core.framework.node_def_pb2 import *
from tensorflow.core.framework.summary_pb2 import * from tensorflow.core.framework.summary_pb2 import *
from tensorflow.core.framework.attr_value_pb2 import * from tensorflow.core.framework.attr_value_pb2 import *
from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo from tensorflow.core.protobuf.meta_graph_pb2 import TensorInfo
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.core.protobuf.config_pb2 import * from tensorflow.core.protobuf.config_pb2 import *
from tensorflow.core.protobuf.tensorflow_server_pb2 import * from tensorflow.core.protobuf.tensorflow_server_pb2 import *
from tensorflow.core.protobuf.rewriter_config_pb2 import * from tensorflow.core.protobuf.rewriter_config_pb2 import *
@ -144,6 +145,7 @@ _allowed_symbols = [
'GraphOptions', 'GraphOptions',
'HistogramProto', 'HistogramProto',
'LogMessage', 'LogMessage',
'MetaGraphDef',
'NameAttrList', 'NameAttrList',
'NodeDef', 'NodeDef',
'OptimizerOptions', 'OptimizerOptions',

View File

@ -41,6 +41,8 @@
@@import_graph_def @@import_graph_def
@@load_file_system_library @@load_file_system_library
@@load_op_library @@load_op_library
@@make_tensor_proto
@@make_ndarray
## Graph collections ## Graph collections
@ -98,6 +100,10 @@ from tensorflow.python.framework.sparse_tensor import convert_to_tensor_or_spars
from tensorflow.python.framework.subscribe import subscribe from tensorflow.python.framework.subscribe import subscribe
from tensorflow.python.framework.importer import import_graph_def from tensorflow.python.framework.importer import import_graph_def
# Utilities for working with Tensors
from tensorflow.python.framework.tensor_util import make_tensor_proto
from tensorflow.python.framework.tensor_util import MakeNdarray as make_ndarray
# Needed when you defined a new Op in C++. # Needed when you defined a new Op in C++.
from tensorflow.python.framework.ops import RegisterGradient from tensorflow.python.framework.ops import RegisterGradient
from tensorflow.python.framework.ops import NotDifferentiable from tensorflow.python.framework.ops import NotDifferentiable

View File

@ -25,9 +25,6 @@ import threading
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.tensorboard.backend.event_processing import directory_watcher from tensorflow.tensorboard.backend.event_processing import directory_watcher
from tensorflow.tensorboard.backend.event_processing import event_file_loader from tensorflow.tensorboard.backend.event_processing import event_file_loader
from tensorflow.tensorboard.backend.event_processing import plugin_asset_util from tensorflow.tensorboard.backend.event_processing import plugin_asset_util
@ -329,7 +326,7 @@ class EventAccumulator(object):
if self._graph is None or self._graph_from_metagraph: if self._graph is None or self._graph_from_metagraph:
# We may have a graph_def in the metagraph. If so, and no # We may have a graph_def in the metagraph. If so, and no
# graph_def is directly available, use this one instead. # graph_def is directly available, use this one instead.
meta_graph = meta_graph_pb2.MetaGraphDef() meta_graph = tf.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph) meta_graph.ParseFromString(self._meta_graph)
if meta_graph.graph_def: if meta_graph.graph_def:
if self._graph is not None: if self._graph is not None:
@ -371,7 +368,7 @@ class EventAccumulator(object):
value: A tf.Summary.Value with a Tensor field. value: A tf.Summary.Value with a Tensor field.
event: The tf.Event containing that value. event: The tf.Event containing that value.
""" """
elements = tensor_util.MakeNdarray(value.tensor) elements = tf.make_ndarray(value.tensor)
# The node_name property of the value object is actually a watch key: a # The node_name property of the value object is actually a watch key: a
# combination of node name, output slot, and a suffix. We capture the # combination of node name, output slot, and a suffix. We capture the
@ -475,7 +472,7 @@ class EventAccumulator(object):
""" """
if self._meta_graph is None: if self._meta_graph is None:
raise ValueError('There is no metagraph in this EventAccumulator') raise ValueError('There is no metagraph in this EventAccumulator')
meta_graph = meta_graph_pb2.MetaGraphDef() meta_graph = tf.MetaGraphDef()
meta_graph.ParseFromString(self._meta_graph) meta_graph.ParseFromString(self._meta_graph)
return meta_graph return meta_graph
@ -698,7 +695,7 @@ class EventAccumulator(object):
device_name=device_name, device_name=device_name,
node_name=node_name, node_name=node_name,
output_slot=output_slot, output_slot=output_slot,
dtype=repr(dtypes.as_dtype(elements[12])), dtype=repr(tf.as_dtype(elements[12])),
shape=list(elements[14:]), shape=list(elements[14:]),
value=list(elements))) value=list(elements)))

View File

@ -24,8 +24,6 @@ import six
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import googletest from tensorflow.python.platform import googletest
from tensorflow.python.summary.writer.writer import SummaryToEventTransformer from tensorflow.python.summary.writer.writer import SummaryToEventTransformer
from tensorflow.tensorboard.backend.event_processing import event_accumulator as ea from tensorflow.tensorboard.backend.event_processing import event_accumulator as ea
@ -64,7 +62,7 @@ class _EventGenerator(object):
tag=ea.HEALTH_PILL_EVENT_TAG_PREFIX + device_name, tag=ea.HEALTH_PILL_EVENT_TAG_PREFIX + device_name,
node_name='%s:%d:DebugNumericSummary' % (op_name, output_slot)) node_name='%s:%d:DebugNumericSummary' % (op_name, output_slot))
value.tensor.tensor_shape.dim.add(size=len(elements)) value.tensor.tensor_shape.dim.add(size=len(elements))
value.tensor.dtype = types_pb2.DT_DOUBLE value.tensor.dtype = 2 # DT_DOUBLE
value.tensor.tensor_content = np.array(elements, dtype=np.float64).tobytes() value.tensor.tensor_content = np.array(elements, dtype=np.float64).tobytes()
self.AddEvent(event) self.AddEvent(event)
@ -282,11 +280,11 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
gen = _EventGenerator(self) gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen) acc = ea.EventAccumulator(gen)
health_pill_elements_1 = list(range(1, 13)) + [ health_pill_elements_1 = list(range(1, 13)) + [
float(types_pb2.DT_FLOAT), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0] float(1), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
gen.AddHealthPill(13371337, 41, '/job:localhost/replica:0/task:0/cpu:0', gen.AddHealthPill(13371337, 41, '/job:localhost/replica:0/task:0/cpu:0',
'Add', 0, health_pill_elements_1) 'Add', 0, health_pill_elements_1)
health_pill_elements_2 = list(range(42, 54)) + [ health_pill_elements_2 = list(range(42, 54)) + [
float(types_pb2.DT_DOUBLE), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0] float(2), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
gen.AddHealthPill(13381338, 42, '/job:localhost/replica:0/task:0/gpu:0', gen.AddHealthPill(13381338, 42, '/job:localhost/replica:0/task:0/gpu:0',
'Add', 1, health_pill_elements_2) 'Add', 1, health_pill_elements_2)
acc.Reload() acc.Reload()
@ -319,11 +317,11 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
gen = _EventGenerator(self) gen = _EventGenerator(self)
acc = ea.EventAccumulator(gen) acc = ea.EventAccumulator(gen)
health_pill_elements_1 = list(range(1, 13)) + [ health_pill_elements_1 = list(range(1, 13)) + [
float(types_pb2.DT_FLOAT), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0] float(1), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
gen.AddHealthPill(13371337, 41, '/job:localhost/replica:0/task:0/cpu:0', gen.AddHealthPill(13371337, 41, '/job:localhost/replica:0/task:0/cpu:0',
'Add', 0, health_pill_elements_1) 'Add', 0, health_pill_elements_1)
health_pill_elements_2 = list(range(42, 54)) + [ health_pill_elements_2 = list(range(42, 54)) + [
float(types_pb2.DT_DOUBLE), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0] float(2), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0]
gen.AddHealthPill(13381338, 42, '/job:localhost/replica:0/task:0/cpu:0', gen.AddHealthPill(13381338, 42, '/job:localhost/replica:0/task:0/cpu:0',
'MatMul', 1, health_pill_elements_2) 'MatMul', 1, health_pill_elements_2)
acc.Reload() acc.Reload()
@ -850,11 +848,11 @@ class MockingEventAccumulatorTest(EventAccumulatorTest):
}) })
scalar_proto = accumulator.Tensors('scalar')[0].tensor_proto scalar_proto = accumulator.Tensors('scalar')[0].tensor_proto
scalar = tensor_util.MakeNdarray(scalar_proto) scalar = tf.make_ndarray(scalar_proto)
vector_proto = accumulator.Tensors('vector')[0].tensor_proto vector_proto = accumulator.Tensors('vector')[0].tensor_proto
vector = tensor_util.MakeNdarray(vector_proto) vector = tf.make_ndarray(vector_proto)
string_proto = accumulator.Tensors('string')[0].tensor_proto string_proto = accumulator.Tensors('string')[0].tensor_proto
string = tensor_util.MakeNdarray(string_proto) string = tf.make_ndarray(string_proto)
self.assertTrue(np.array_equal(scalar, 1.0)) self.assertTrue(np.array_equal(scalar, 1.0))
self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0])) self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0]))

View File

@ -32,9 +32,9 @@ import bleach
import markdown import markdown
import six import six
# pylint: enable=g-bad-import-order # pylint: enable=g-bad-import-order
import tensorflow as tf
from werkzeug import wrappers from werkzeug import wrappers
from tensorflow.python.framework import tensor_util
from tensorflow.python.summary import text_summary from tensorflow.python.summary import text_summary
from tensorflow.tensorboard.backend import http_util from tensorflow.tensorboard.backend import http_util
from tensorflow.tensorboard.plugins import base_plugin from tensorflow.tensorboard.plugins import base_plugin
@ -240,7 +240,7 @@ def text_array_to_html(text_arr):
def process_string_tensor_event(event): def process_string_tensor_event(event):
"""Convert a TensorEvent into a JSON-compatible response.""" """Convert a TensorEvent into a JSON-compatible response."""
string_arr = tensor_util.MakeNdarray(event.tensor_proto) string_arr = tf.make_ndarray(event.tensor_proto)
html = text_array_to_html(string_arr) html = text_array_to_html(string_arr)
return { return {
'wall_time': event.wall_time, 'wall_time': event.wall_time,

View File

@ -0,0 +1,84 @@
path: "tensorflow.MetaGraphDef.CollectionDefEntry"
tf_class {
is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.CollectionDefEntry\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "DESCRIPTOR"
mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
}
member {
name: "Extensions"
mtype: "<type \'getset_descriptor\'>"
}
member {
name: "KEY_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "VALUE_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member_method {
name: "ByteSize"
}
member_method {
name: "Clear"
}
member_method {
name: "ClearExtension"
}
member_method {
name: "ClearField"
}
member_method {
name: "CopyFrom"
}
member_method {
name: "DiscardUnknownFields"
}
member_method {
name: "FindInitializationErrors"
}
member_method {
name: "FromString"
}
member_method {
name: "HasExtension"
}
member_method {
name: "HasField"
}
member_method {
name: "IsInitialized"
}
member_method {
name: "ListFields"
}
member_method {
name: "MergeFrom"
}
member_method {
name: "MergeFromString"
}
member_method {
name: "ParseFromString"
}
member_method {
name: "RegisterExtension"
}
member_method {
name: "SerializePartialToString"
}
member_method {
name: "SerializeToString"
}
member_method {
name: "SetInParent"
}
member_method {
name: "WhichOneof"
}
member_method {
name: "__init__"
}
}

View File

@ -0,0 +1,100 @@
path: "tensorflow.MetaGraphDef.MetaInfoDef"
tf_class {
is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.MetaInfoDef\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "ANY_INFO_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "DESCRIPTOR"
mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
}
member {
name: "Extensions"
mtype: "<type \'getset_descriptor\'>"
}
member {
name: "META_GRAPH_VERSION_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "STRIPPED_OP_LIST_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "TAGS_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "TENSORFLOW_GIT_VERSION_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "TENSORFLOW_VERSION_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member_method {
name: "ByteSize"
}
member_method {
name: "Clear"
}
member_method {
name: "ClearExtension"
}
member_method {
name: "ClearField"
}
member_method {
name: "CopyFrom"
}
member_method {
name: "DiscardUnknownFields"
}
member_method {
name: "FindInitializationErrors"
}
member_method {
name: "FromString"
}
member_method {
name: "HasExtension"
}
member_method {
name: "HasField"
}
member_method {
name: "IsInitialized"
}
member_method {
name: "ListFields"
}
member_method {
name: "MergeFrom"
}
member_method {
name: "MergeFromString"
}
member_method {
name: "ParseFromString"
}
member_method {
name: "RegisterExtension"
}
member_method {
name: "SerializePartialToString"
}
member_method {
name: "SerializeToString"
}
member_method {
name: "SetInParent"
}
member_method {
name: "WhichOneof"
}
member_method {
name: "__init__"
}
}

View File

@ -0,0 +1,84 @@
path: "tensorflow.MetaGraphDef.SignatureDefEntry"
tf_class {
is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.SignatureDefEntry\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "DESCRIPTOR"
mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
}
member {
name: "Extensions"
mtype: "<type \'getset_descriptor\'>"
}
member {
name: "KEY_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "VALUE_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member_method {
name: "ByteSize"
}
member_method {
name: "Clear"
}
member_method {
name: "ClearExtension"
}
member_method {
name: "ClearField"
}
member_method {
name: "CopyFrom"
}
member_method {
name: "DiscardUnknownFields"
}
member_method {
name: "FindInitializationErrors"
}
member_method {
name: "FromString"
}
member_method {
name: "HasExtension"
}
member_method {
name: "HasField"
}
member_method {
name: "IsInitialized"
}
member_method {
name: "ListFields"
}
member_method {
name: "MergeFrom"
}
member_method {
name: "MergeFromString"
}
member_method {
name: "ParseFromString"
}
member_method {
name: "RegisterExtension"
}
member_method {
name: "SerializePartialToString"
}
member_method {
name: "SerializeToString"
}
member_method {
name: "SetInParent"
}
member_method {
name: "WhichOneof"
}
member_method {
name: "__init__"
}
}

View File

@ -0,0 +1,112 @@
path: "tensorflow.MetaGraphDef"
tf_class {
is_instance: "<class \'tensorflow.core.protobuf.meta_graph_pb2.MetaGraphDef\'>"
is_instance: "<type \'google.protobuf.pyext._message.CMessage\'>"
member {
name: "ASSET_FILE_DEF_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "COLLECTION_DEF_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "CollectionDefEntry"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
name: "DESCRIPTOR"
mtype: "<type \'google.protobuf.pyext._message.MessageDescriptor\'>"
}
member {
name: "Extensions"
mtype: "<type \'getset_descriptor\'>"
}
member {
name: "GRAPH_DEF_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "META_INFO_DEF_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "MetaInfoDef"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member {
name: "SAVER_DEF_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "SIGNATURE_DEF_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
member {
name: "SignatureDefEntry"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member_method {
name: "ByteSize"
}
member_method {
name: "Clear"
}
member_method {
name: "ClearExtension"
}
member_method {
name: "ClearField"
}
member_method {
name: "CopyFrom"
}
member_method {
name: "DiscardUnknownFields"
}
member_method {
name: "FindInitializationErrors"
}
member_method {
name: "FromString"
}
member_method {
name: "HasExtension"
}
member_method {
name: "HasField"
}
member_method {
name: "IsInitialized"
}
member_method {
name: "ListFields"
}
member_method {
name: "MergeFrom"
}
member_method {
name: "MergeFromString"
}
member_method {
name: "ParseFromString"
}
member_method {
name: "RegisterExtension"
}
member_method {
name: "SerializePartialToString"
}
member_method {
name: "SerializeToString"
}
member_method {
name: "SetInParent"
}
member_method {
name: "WhichOneof"
}
member_method {
name: "__init__"
}
}

View File

@ -116,6 +116,10 @@ tf_module {
name: "LogMessage" name: "LogMessage"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>" mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
} }
member {
name: "MetaGraphDef"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
}
member { member {
name: "NameAttrList" name: "NameAttrList"
mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>" mtype: "<class \'google.protobuf.pyext.cpp_message.GeneratedProtocolMessageType\'>"
@ -1196,10 +1200,18 @@ tf_module {
name: "logical_xor" name: "logical_xor"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'LogicalXor\'], " argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'LogicalXor\'], "
} }
member_method {
name: "make_ndarray"
argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "make_template" name: "make_template"
argspec: "args=[\'name_\', \'func_\', \'create_scope_now_\', \'unique_name_\', \'custom_getter_\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\', \'None\'], " argspec: "args=[\'name_\', \'func_\', \'create_scope_now_\', \'unique_name_\', \'custom_getter_\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\', \'None\'], "
} }
member_method {
name: "make_tensor_proto"
argspec: "args=[\'values\', \'dtype\', \'shape\', \'verify_shape\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\'], "
}
member_method { member_method {
name: "map_fn" name: "map_fn"
argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'None\'], " argspec: "args=[\'fn\', \'elems\', \'dtype\', \'parallel_iterations\', \'back_prop\', \'swap_memory\', \'infer_shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'True\', \'False\', \'True\', \'None\'], "