Use CompactTextString instead of String for generating ops.

PiperOrigin-RevId: 311655146
Change-Id: I57e5c595522b47dd9badbf0720569ffef69fed66
This commit is contained in:
Jonathan Hseu 2020-05-14 19:43:19 -07:00 committed by TensorFlower Gardener
parent 3b225a9776
commit a98948acf8
2 changed files with 259 additions and 5 deletions

View File

@ -567,11 +567,10 @@ func isListAttr(attrdef *odpb.OpDef_AttrDef) bool {
// This is useful when 's' corresponds to a "oneof" protocol buffer message.
// For example, consider the protocol buffer message:
// oneof value { bool b = 1; int64 i = 2; }
// String() on a Go corresponding object (using proto.CompactTextString) will
// print "b:true", or "i:7" etc. This function strips out the leading "b:" or
// "i:".
func stripLeadingColon(s fmt.Stringer) string {
x := s.String()
// proto.CompactTextString) will print "b:true", or "i:7" etc. This function
// strips out the leading "b:" or "i:".
func stripLeadingColon(m proto.Message) string {
x := proto.CompactTextString(m)
y := strings.SplitN(x, ":", 2)
if len(y) < 2 {
return x

View File

@ -533,6 +533,261 @@ func TestOp(scope *Scope, bb tf.Output, aa tf.Output, optional ...TestOpAttr) (c
op := scope.AddOperation(opspec)
return op.Output(0)
}
`,
},
{
tag: "SampleDistortedBoundingBox",
opdef: `
name: "SampleDistortedBoundingBox"
input_arg {
name: "image_size"
type_attr: "T"
}
input_arg {
name: "bounding_boxes"
type: DT_FLOAT
}
output_arg {
name: "begin"
type_attr: "T"
}
output_arg {
name: "size"
type_attr: "T"
}
output_arg {
name: "bboxes"
type: DT_FLOAT
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_UINT8
type: DT_INT8
type: DT_INT16
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "seed"
type: "int"
default_value {
i: 0
}
}
attr {
name: "seed2"
type: "int"
default_value {
i: 0
}
}
attr {
name: "min_object_covered"
type: "float"
default_value {
f: 0.1
}
}
attr {
name: "aspect_ratio_range"
type: "list(float)"
default_value {
list {
f: 0.75
f: 1.33
}
}
}
attr {
name: "area_range"
type: "list(float)"
default_value {
list {
f: 0.05
f: 1
}
}
}
attr {
name: "max_attempts"
type: "int"
default_value {
i: 100
}
}
attr {
name: "use_image_if_no_bounding_boxes"
type: "bool"
default_value {
b: false
}
}
is_stateful: true
`,
apidef: `
op {
graph_op_name: "SampleDistortedBoundingBox"
in_arg {
name: "image_size"
description: "Blah blah"
}
in_arg {
name: "bounding_boxes"
description: "Blah blah"
}
out_arg {
name: "begin"
description: "Blah blah"
}
out_arg {
name: "size"
description: "Blah blah"
}
out_arg {
name: "bboxes"
description: "Blah blah"
}
attr {
name: "seed"
description: "Blah blah"
}
attr {
name: "seed2"
description: "Blah blah"
}
attr {
name: "min_object_covered"
description: "Blah blah"
}
attr {
name: "aspect_ratio_range"
description: "Blah blah"
}
attr {
name: "area_range"
description: "Blah blah"
}
attr {
name: "max_attempts"
description: "Blah blah"
}
attr {
name: "use_image_if_no_bounding_boxes"
description: "Blah blah"
}
summary: "Generate a single randomly distorted bounding box for an image."
description: "Blah blah"
}
`,
wanted: `
// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox.
type SampleDistortedBoundingBoxAttr func(optionalAttr)
// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value.
//
// value: Blah blah
// If not specified, defaults to 0
func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["seed"] = value
}
}
// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value.
//
// value: Blah blah
// If not specified, defaults to 0
func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["seed2"] = value
}
}
// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value.
//
// value: Blah blah
// If not specified, defaults to 0.1
func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["min_object_covered"] = value
}
}
// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value.
//
// value: Blah blah
// If not specified, defaults to <f:0.75 f:1.33 >
func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["aspect_ratio_range"] = value
}
}
// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
//
// value: Blah blah
// If not specified, defaults to <f:0.05 f:1 >
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["area_range"] = value
}
}
// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value.
//
// value: Blah blah
// If not specified, defaults to 100
func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["max_attempts"] = value
}
}
// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value.
//
// value: Blah blah
// If not specified, defaults to false
func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["use_image_if_no_bounding_boxes"] = value
}
}
// Generate a single randomly distorted bounding box for an image.
//
// Blah blah
//
// Arguments:
// image_size: Blah blah
// bounding_boxes: Blah blah
//
// Returns:
// begin: Blah blah
// size: Blah blah
// bboxes: Blah blah
func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "SampleDistortedBoundingBox",
Input: []tf.Input{
image_size, bounding_boxes,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1), op.Output(2)
}
`,
},
}