Fix the exporting of tf.io.decode_proto and tf.io.encode_proto so that they correctly trigger dispatching. Note: Although this changes the api_def pbtxt files to change how exporting happens, the actual exported apis should not changed.
PiperOrigin-RevId: 320641446 Change-Id: I70760649794f08fffb629e25244802f73cf8b256
This commit is contained in:
parent
c14d4f0e21
commit
4d97a127c6
@ -1,6 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "DecodeProtoV2"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "bytes"
|
||||
description: <<END
|
||||
|
@ -1,6 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "EncodeProto"
|
||||
visibility: HIDDEN
|
||||
in_arg {
|
||||
name: "sizes"
|
||||
description: <<END
|
||||
|
@ -0,0 +1,6 @@
|
||||
op {
|
||||
graph_op_name: "DecodeProtoV2"
|
||||
endpoint {
|
||||
name: "io.decode_proto"
|
||||
}
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
op {
|
||||
graph_op_name: "EncodeProto"
|
||||
endpoint {
|
||||
name: "io.encode_proto"
|
||||
}
|
||||
}
|
@ -19,14 +19,11 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# pylint: disable=unused-import
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops.gen_decode_proto_ops import decode_proto_v2 as decode_proto
|
||||
from tensorflow.python.ops.gen_encode_proto_ops import encode_proto
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
tf_export("io.decode_proto")(dispatch.add_dispatch_support(decode_proto))
|
||||
tf_export("io.encode_proto")(dispatch.add_dispatch_support(encode_proto))
|
||||
# pylint: enable=unused-import
|
||||
|
||||
ops.NotDifferentiable("DecodeProtoV2")
|
||||
ops.NotDifferentiable("EncodeProto")
|
||||
|
@ -22,6 +22,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.proto_ops import decode_proto
|
||||
from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging
|
||||
@ -188,6 +189,10 @@ class DispatchTest(test_util.TensorFlowTestCase):
|
||||
str(trace),
|
||||
"math.reduce_sum(math.add(name=None, x=math.abs(x), y=y), axis=3)")
|
||||
|
||||
proto_val = TensorTracer("proto")
|
||||
trace = decode_proto(proto_val, "message_type", ["field"], ["float32"])
|
||||
self.assertIn("io.decode_proto(bytes=proto,", str(trace))
|
||||
|
||||
finally:
|
||||
# Clean up.
|
||||
dispatch._GLOBAL_DISPATCHERS = original_global_dispatchers
|
||||
|
Loading…
Reference in New Issue
Block a user