Add tf.summary.experimental.write_raw_pb() to replace import_event()
This removes import_event() from the TF 2.0 tf.summary API and replaces it with an experimental write_raw_pb() function, which takes one or more tf.compat.v1.Summary serialized protocol buffers (as a string tensor) and a step value and writes them to the default summary writer. The new op provides a way to use legacy code that constructs tf.compat.v1.Summary objects by hand with the TF 2.0 summary writing API, and unlike import_event(), the op behaves like write() in terms of respecting the recording condition. It also doesn't require constructing a tf.compat.v1.Event protobuf to wrap the Summary. While adjusting the 2.0 APIs, I also moved summary_scope() into tf.summary.experimental since it's meant primarily as an implementation detail of 2.0 summary APIs and is likely to change in the future. PiperOrigin-RevId: 243177697
This commit is contained in:
parent
4fc3d561ed
commit
1ea3483d48
@ -0,0 +1,4 @@
|
||||
op {
|
||||
graph_op_name: "WriteRawProtoSummary"
|
||||
visibility: HIDDEN
|
||||
}
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/lib/db/sqlite.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/summary/schema.h"
|
||||
@ -147,6 +148,43 @@ class WriteSummaryOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("WriteSummary").Device(DEVICE_CPU),
|
||||
WriteSummaryOp);
|
||||
|
||||
class WriteRawProtoSummaryOp : public OpKernel {
|
||||
public:
|
||||
explicit WriteRawProtoSummaryOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void Compute(OpKernelContext* ctx) override {
|
||||
SummaryWriterInterface* s;
|
||||
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &s));
|
||||
core::ScopedUnref unref(s);
|
||||
const Tensor* tmp;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("step", &tmp));
|
||||
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tmp->shape()),
|
||||
errors::InvalidArgument("step must be scalar, got shape ",
|
||||
tmp->shape().DebugString()));
|
||||
const int64 step = tmp->scalar<int64>()();
|
||||
const Tensor* t;
|
||||
OP_REQUIRES_OK(ctx, ctx->input("tensor", &t));
|
||||
std::unique_ptr<Event> event{new Event};
|
||||
event->set_step(step);
|
||||
event->set_wall_time(static_cast<double>(ctx->env()->NowMicros()) / 1.0e6);
|
||||
// Each Summary proto contains just one repeated field "value" of Value
|
||||
// messages with the actual data, so repeated Merge() is equivalent to
|
||||
// concatenating all the Value entries together into a single Event.
|
||||
const auto summary_pbs = t->flat<string>();
|
||||
for (int i = 0; i < summary_pbs.size(); ++i) {
|
||||
if (!event->mutable_summary()->MergeFromString(summary_pbs(i))) {
|
||||
ctx->CtxFailureWithWarning(errors::DataLoss(
|
||||
"Bad tf.compat.v1.Summary binary proto tensor string at index ",
|
||||
i));
|
||||
return;
|
||||
}
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, s->WriteEvent(std::move(event)));
|
||||
}
|
||||
};
|
||||
REGISTER_KERNEL_BUILDER(Name("WriteRawProtoSummary").Device(DEVICE_CPU),
|
||||
WriteRawProtoSummaryOp);
|
||||
|
||||
class ImportEventOp : public OpKernel {
|
||||
public:
|
||||
explicit ImportEventOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
@ -57,6 +57,12 @@ REGISTER_OP("WriteSummary")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
REGISTER_OP("WriteRawProtoSummary")
|
||||
.Input("writer: resource")
|
||||
.Input("step: int64")
|
||||
.Input("tensor: string")
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
|
||||
REGISTER_OP("ImportEvent")
|
||||
.Input("writer: resource")
|
||||
.Input("event: string")
|
||||
|
@ -407,6 +407,69 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(2, events[2].step)
|
||||
self.assertEqual(4, events[3].step)
|
||||
|
||||
def testWriteRawPb(self):
|
||||
logdir = self.get_temp_dir()
|
||||
pb = summary_pb2.Summary()
|
||||
pb.value.add().simple_value = 42.0
|
||||
with context.eager_mode():
|
||||
with summary_ops.create_file_writer_v2(logdir).as_default():
|
||||
output = summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
|
||||
self.assertTrue(output.numpy())
|
||||
events = events_from_logdir(logdir)
|
||||
self.assertEqual(2, len(events))
|
||||
self.assertEqual(12, events[1].step)
|
||||
self.assertProtoEquals(pb, events[1].summary)
|
||||
|
||||
def testWriteRawPb_fromFunction(self):
|
||||
logdir = self.get_temp_dir()
|
||||
pb = summary_pb2.Summary()
|
||||
pb.value.add().simple_value = 42.0
|
||||
with context.eager_mode():
|
||||
writer = summary_ops.create_file_writer_v2(logdir)
|
||||
@def_function.function
|
||||
def f():
|
||||
with writer.as_default():
|
||||
return summary_ops.write_raw_pb(pb.SerializeToString(), step=12)
|
||||
output = f()
|
||||
self.assertTrue(output.numpy())
|
||||
events = events_from_logdir(logdir)
|
||||
self.assertEqual(2, len(events))
|
||||
self.assertEqual(12, events[1].step)
|
||||
self.assertProtoEquals(pb, events[1].summary)
|
||||
|
||||
def testWriteRawPb_multipleValues(self):
|
||||
logdir = self.get_temp_dir()
|
||||
pb1 = summary_pb2.Summary()
|
||||
pb1.value.add().simple_value = 1.0
|
||||
pb1.value.add().simple_value = 2.0
|
||||
pb2 = summary_pb2.Summary()
|
||||
pb2.value.add().simple_value = 3.0
|
||||
pb3 = summary_pb2.Summary()
|
||||
pb3.value.add().simple_value = 4.0
|
||||
pb3.value.add().simple_value = 5.0
|
||||
pb3.value.add().simple_value = 6.0
|
||||
pbs = [pb.SerializeToString() for pb in (pb1, pb2, pb3)]
|
||||
with context.eager_mode():
|
||||
with summary_ops.create_file_writer_v2(logdir).as_default():
|
||||
output = summary_ops.write_raw_pb(pbs, step=12)
|
||||
self.assertTrue(output.numpy())
|
||||
events = events_from_logdir(logdir)
|
||||
self.assertEqual(2, len(events))
|
||||
self.assertEqual(12, events[1].step)
|
||||
expected_pb = summary_pb2.Summary()
|
||||
for i in range(6):
|
||||
expected_pb.value.add().simple_value = i + 1.0
|
||||
self.assertProtoEquals(expected_pb, events[1].summary)
|
||||
|
||||
def testWriteRawPb_invalidValue(self):
|
||||
logdir = self.get_temp_dir()
|
||||
with context.eager_mode():
|
||||
with summary_ops.create_file_writer_v2(logdir).as_default():
|
||||
with self.assertRaisesRegex(
|
||||
errors.DataLossError,
|
||||
'Bad tf.compat.v1.Summary binary proto tensor string'):
|
||||
summary_ops.write_raw_pb('notaproto', step=12)
|
||||
|
||||
@test_util.also_run_as_tf_function
|
||||
def testGetSetStep(self):
|
||||
try:
|
||||
|
@ -542,10 +542,10 @@ def summary_writer_initializer_op():
|
||||
_INVALID_SCOPE_CHARACTERS = re.compile(r"[^-_/.A-Za-z0-9]")
|
||||
|
||||
|
||||
@tf_export("summary.summary_scope", v1=[])
|
||||
@tf_export("summary.experimental.summary_scope", v1=[])
|
||||
@tf_contextlib.contextmanager
|
||||
def summary_scope(name, default_name="summary", values=None):
|
||||
"""A context manager for use when defining a custom summary op.
|
||||
"""Experimental context manager for use when defining a custom summary op.
|
||||
|
||||
This behaves similarly to `tf.name_scope`, except that it returns a generated
|
||||
summary tag in addition to the scope name. The tag is structurally similar to
|
||||
@ -642,6 +642,54 @@ def write(tag, tensor, step=None, metadata=None, name=None):
|
||||
_should_record_summaries_v2(), record, _nothing, name="summary_cond")
|
||||
|
||||
|
||||
@tf_export("summary.experimental.write_raw_pb", v1=[])
|
||||
def write_raw_pb(tensor, step=None, name=None):
|
||||
"""Writes a summary using raw `tf.compat.v1.Summary` protocol buffers.
|
||||
|
||||
Experimental: this exists to support the usage of V1-style manual summary
|
||||
writing (via the construction of a `tf.compat.v1.Summary` protocol buffer)
|
||||
with the V2 summary writing API.
|
||||
|
||||
Args:
|
||||
tensor: the string Tensor holding one or more serialized `Summary` protobufs
|
||||
step: Explicit `int64`-castable monotonic step value for this summary. If
|
||||
omitted, this defaults to `tf.summary.experimental.get_step()`, which must
|
||||
not be None.
|
||||
name: Optional string name for this op.
|
||||
|
||||
Returns:
|
||||
True on success, or false if no summary was written because no default
|
||||
summary writer was available.
|
||||
|
||||
Raises:
|
||||
ValueError: if a default writer exists, but no step was provided and
|
||||
`tf.summary.experimental.get_step()` is None.
|
||||
"""
|
||||
with ops.name_scope(name, "write_raw_pb") as scope:
|
||||
if context.context().summary_writer is None:
|
||||
return constant_op.constant(False)
|
||||
if step is None:
|
||||
step = get_step()
|
||||
if step is None:
|
||||
raise ValueError("No step set via 'step' argument or "
|
||||
"tf.summary.experimental.set_step()")
|
||||
|
||||
def record():
|
||||
"""Record the actual summary and return True."""
|
||||
# Note the identity to move the tensor to the CPU.
|
||||
with ops.device("cpu:0"):
|
||||
raw_summary_op = gen_summary_ops.write_raw_proto_summary(
|
||||
context.context().summary_writer._resource, # pylint: disable=protected-access
|
||||
step,
|
||||
array_ops.identity(tensor),
|
||||
name=scope)
|
||||
with ops.control_dependencies([raw_summary_op]):
|
||||
return constant_op.constant(True)
|
||||
|
||||
return smart_cond.smart_cond(
|
||||
_should_record_summaries_v2(), record, _nothing, name="summary_cond")
|
||||
|
||||
|
||||
def summary_writer_function(name, tensor, function, family=None):
|
||||
"""Helper function to write summaries.
|
||||
|
||||
@ -826,7 +874,6 @@ def graph(param, step=None, name=None):
|
||||
_graph = graph # for functions with a graph parameter
|
||||
|
||||
|
||||
@tf_export("summary.import_event", v1=[])
|
||||
def import_event(tensor, name=None):
|
||||
"""Writes a `tf.Event` binary proto.
|
||||
|
||||
|
@ -4400,6 +4400,10 @@ tf_module {
|
||||
name: "WriteImageSummary"
|
||||
argspec: "args=[\'writer\', \'step\', \'tag\', \'tensor\', \'bad_color\', \'max_images\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "WriteRawProtoSummary"
|
||||
argspec: "args=[\'writer\', \'step\', \'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "WriteScalarSummary"
|
||||
argspec: "args=[\'writer\', \'step\', \'tag\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -4400,6 +4400,10 @@ tf_module {
|
||||
name: "WriteImageSummary"
|
||||
argspec: "args=[\'writer\', \'step\', \'tag\', \'tensor\', \'bad_color\', \'max_images\', \'name\'], varargs=None, keywords=None, defaults=[\'3\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "WriteRawProtoSummary"
|
||||
argspec: "args=[\'writer\', \'step\', \'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "WriteScalarSummary"
|
||||
argspec: "args=[\'writer\', \'step\', \'tag\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -8,4 +8,12 @@ tf_module {
|
||||
name: "set_step"
|
||||
argspec: "args=[\'step\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "summary_scope"
|
||||
argspec: "args=[\'name\', \'default_name\', \'values\'], varargs=None, keywords=None, defaults=[\'summary\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "write_raw_pb"
|
||||
argspec: "args=[\'tensor\', \'step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
}
|
||||
|
@ -32,10 +32,6 @@ tf_module {
|
||||
name: "image"
|
||||
argspec: "args=[\'name\', \'data\', \'step\', \'max_outputs\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'3\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "import_event"
|
||||
argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "record_if"
|
||||
argspec: "args=[\'condition\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -44,10 +40,6 @@ tf_module {
|
||||
name: "scalar"
|
||||
argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "summary_scope"
|
||||
argspec: "args=[\'name\', \'default_name\', \'values\'], varargs=None, keywords=None, defaults=[\'summary\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "text"
|
||||
argspec: "args=[\'name\', \'data\', \'step\', \'description\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user