Fix the exporting of tf.io.decode_proto and tf.io.encode_proto so that they correctly trigger dispatching. Note: Although this changes the api_def pbtxt files to change how exporting happens, the actual exported apis should not changed.

PiperOrigin-RevId: 320641446
Change-Id: I70760649794f08fffb629e25244802f73cf8b256
This commit is contained in:
Tomer Kaftan 2020-07-10 11:31:12 -07:00 committed by TensorFlower Gardener
parent c14d4f0e21
commit 4d97a127c6
6 changed files with 19 additions and 7 deletions

View File

@ -1,6 +1,5 @@
op {
graph_op_name: "DecodeProtoV2"
visibility: HIDDEN
in_arg {
name: "bytes"
description: <<END

View File

@ -1,6 +1,5 @@
op {
graph_op_name: "EncodeProto"
visibility: HIDDEN
in_arg {
name: "sizes"
description: <<END

View File

@ -0,0 +1,6 @@
op {
graph_op_name: "DecodeProtoV2"
endpoint {
name: "io.decode_proto"
}
}

View File

@ -0,0 +1,6 @@
op {
graph_op_name: "EncodeProto"
endpoint {
name: "io.encode_proto"
}
}

View File

@ -19,14 +19,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.framework import ops
from tensorflow.python.ops.gen_decode_proto_ops import decode_proto_v2 as decode_proto
from tensorflow.python.ops.gen_encode_proto_ops import encode_proto
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
tf_export("io.decode_proto")(dispatch.add_dispatch_support(decode_proto))
tf_export("io.encode_proto")(dispatch.add_dispatch_support(encode_proto))
# pylint: enable=unused-import
ops.NotDifferentiable("DecodeProtoV2")
ops.NotDifferentiable("EncodeProto")

View File

@ -22,6 +22,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.proto_ops import decode_proto
from tensorflow.python.platform import googletest
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
@ -188,6 +189,10 @@ class DispatchTest(test_util.TensorFlowTestCase):
str(trace),
"math.reduce_sum(math.add(name=None, x=math.abs(x), y=y), axis=3)")
proto_val = TensorTracer("proto")
trace = decode_proto(proto_val, "message_type", ["field"], ["float32"])
self.assertIn("io.decode_proto(bytes=proto,", str(trace))
finally:
# Clean up.
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers