Go: Support setting shape valued attributes.
Fixes #6833 Change: 144752893
This commit is contained in:
parent
0662eabf9d
commit
f8d75baaf4
@ -395,7 +395,7 @@ func goType(tfType string) (string, error) {
|
||||
case "type":
|
||||
gotype = "tf.DataType"
|
||||
case "shape":
|
||||
gotype = "[]int64"
|
||||
gotype = "tf.Shape"
|
||||
case "tensor":
|
||||
gotype = "tf.Tensor"
|
||||
case "string":
|
||||
|
@ -259,13 +259,38 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
|
||||
if err := status.Err(); err != nil {
|
||||
return fmt.Errorf("bad value for attribute %q: %v", name, err)
|
||||
}
|
||||
case Shape:
|
||||
ndims, dims := cshape(value)
|
||||
var dimsp *C.int64_t
|
||||
if ndims > 0 {
|
||||
dimsp = &dims[0]
|
||||
}
|
||||
C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
|
||||
case []Shape:
|
||||
ndims := make([]C.int, len(value))
|
||||
dims := make([][]C.int64_t, len(value))
|
||||
dimsp := make([]*C.int64_t, len(value))
|
||||
for i, s := range value {
|
||||
ndims[i], dims[i] = cshape(s)
|
||||
if ndims[i] > 0 {
|
||||
dimsp[i] = &dims[i][0]
|
||||
}
|
||||
}
|
||||
C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value)))
|
||||
default:
|
||||
// Shapes can be done, but will require that it be
|
||||
// distinguishable from []int64. Which is fine, it
|
||||
// probably makes sense to define a Shape type anyway,
|
||||
// since that should handle partially known shapes as
|
||||
// well and hide the special meaning of -1?
|
||||
return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cshape(s Shape) (C.int, []C.int64_t) {
|
||||
ndims := C.int(s.NumDimensions())
|
||||
if ndims < 0 {
|
||||
return -1, nil
|
||||
}
|
||||
dims := make([]C.int64_t, ndims)
|
||||
for i, s := range s.dims {
|
||||
dims[i] = C.int64_t(s)
|
||||
}
|
||||
return ndims, dims
|
||||
}
|
||||
|
33
tensorflow/go/op/op_test.go
Normal file
33
tensorflow/go/op/op_test.go
Normal file
@ -0,0 +1,33 @@
|
||||
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Tests for the generated code of some operations.
|
||||
|
||||
package op
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
tf "github.com/tensorflow/tensorflow/tensorflow/go"
|
||||
)
|
||||
|
||||
func TestPlaceholder(t *testing.T) {
|
||||
s := NewScope()
|
||||
Placeholder(s.SubScope("x"), tf.Float, PlaceholderShape(tf.MakeShape(-1, 10)))
|
||||
Placeholder(s.SubScope("y"), tf.Float, PlaceholderShape(tf.ScalarShape()))
|
||||
Placeholder(s.SubScope("z"), tf.Float, PlaceholderShape(tf.Shape{}))
|
||||
if _, err := s.Finalize(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
@ -81,6 +81,21 @@ func TestOperationOutputListSize(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperationShapeAttribute(t *testing.T) {
|
||||
g := NewGraph()
|
||||
_, err := g.AddOperation(OpSpec{
|
||||
Type: "Placeholder",
|
||||
Attrs: map[string]interface{}{
|
||||
"dtype": Float,
|
||||
"shape": MakeShape(-1, 3),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// If and when the API to get attributes is added, check that here.
|
||||
}
|
||||
|
||||
func TestOutputShape(t *testing.T) {
|
||||
graph := NewGraph()
|
||||
testdata := []struct {
|
||||
|
102
tensorflow/go/shape.go
Normal file
102
tensorflow/go/shape.go
Normal file
@ -0,0 +1,102 @@
|
||||
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Shape represents the (possibly partially known) shape of a tensor that will
|
||||
// be produced by an operation.
|
||||
//
|
||||
// The zero-value of a Shape represents a shape with an unknown number of
|
||||
// dimensions.
|
||||
type Shape struct {
|
||||
dims []int64
|
||||
}
|
||||
|
||||
// ScalarShape returns a Shape representing a scalar.
|
||||
func ScalarShape() Shape {
|
||||
return Shape{dims: make([]int64, 0)}
|
||||
}
|
||||
|
||||
// MakeShape returns a Shape with the provided size of each dimension.
|
||||
//
|
||||
// A value of -1 implies that the size of the corresponding dimension is not
|
||||
// known.
|
||||
func MakeShape(shape ...int64) Shape {
|
||||
cpy := make([]int64, len(shape))
|
||||
copy(cpy, shape)
|
||||
return Shape{dims: cpy}
|
||||
}
|
||||
|
||||
// NumDimensions returns the number of dimensions represented by s, or -1 if
|
||||
// unknown.
|
||||
func (s Shape) NumDimensions() int {
|
||||
if s.dims == nil {
|
||||
return -1
|
||||
}
|
||||
return len(s.dims)
|
||||
}
|
||||
|
||||
// Size returns the size of the dim-th dimension of the shape, or -1 if it
|
||||
// is unknown.
|
||||
//
|
||||
// REQUIRES: 0 <= dim < s.NumDimensions()
|
||||
func (s Shape) Size(dim int) int64 {
|
||||
if dim < 0 || dim > s.NumDimensions() {
|
||||
return -1
|
||||
}
|
||||
return s.dims[dim]
|
||||
}
|
||||
|
||||
// IsFullySpecified returns true iff the size of all the dimensions of s are
|
||||
// known.
|
||||
func (s Shape) IsFullySpecified() bool {
|
||||
if s.dims == nil {
|
||||
return false
|
||||
}
|
||||
for _, size := range s.dims {
|
||||
if size <= 1 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ToSlice returns the (possibly partially known) shape represented by s as a
|
||||
// slice, or an error if the number of dimensions is not known.
|
||||
func (s Shape) ToSlice() ([]int64, error) {
|
||||
if s.dims == nil {
|
||||
return nil, fmt.Errorf("cannot create a slice for a Shape with an unknown number of dimensions")
|
||||
}
|
||||
cpy := make([]int64, len(s.dims))
|
||||
copy(cpy, s.dims)
|
||||
return cpy, nil
|
||||
}
|
||||
|
||||
func (s Shape) String() string {
|
||||
if s.dims == nil {
|
||||
return "?"
|
||||
}
|
||||
ret := fmt.Sprint(s.dims)
|
||||
for _, size := range s.dims {
|
||||
if size < 0 {
|
||||
ret = strings.Replace(ret, fmt.Sprint(size), "?", 1)
|
||||
}
|
||||
}
|
||||
return strings.Replace(ret, " ", ", ", -1)
|
||||
}
|
83
tensorflow/go/shape_test.go
Normal file
83
tensorflow/go/shape_test.go
Normal file
@ -0,0 +1,83 @@
|
||||
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tensorflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShape(t *testing.T) {
|
||||
tests := []struct {
|
||||
shape Shape
|
||||
slice []int64
|
||||
full bool
|
||||
str string
|
||||
}{
|
||||
{
|
||||
shape: ScalarShape(),
|
||||
slice: make([]int64, 0),
|
||||
full: true,
|
||||
str: "[]",
|
||||
},
|
||||
{
|
||||
shape: MakeShape(-1, 2, -1, 4),
|
||||
slice: []int64{-1, 2, -1, 4},
|
||||
full: false,
|
||||
str: "[?, 2, ?, 4]",
|
||||
},
|
||||
{
|
||||
shape: MakeShape(2, 3),
|
||||
slice: []int64{2, 3},
|
||||
full: true,
|
||||
str: "[2, 3]",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%#v", test.shape), func(t *testing.T) {
|
||||
if got, want := test.shape.NumDimensions(), len(test.slice); got != want {
|
||||
t.Errorf("Got %v, want %v", got, want)
|
||||
}
|
||||
if gotSlice, err := test.shape.ToSlice(); err != nil || !reflect.DeepEqual(gotSlice, test.slice) {
|
||||
t.Errorf("Got (%#v, %v), want (%#v, nil)", gotSlice, err, test.slice)
|
||||
}
|
||||
if got, want := test.shape.IsFullySpecified(), test.full; got != want {
|
||||
t.Errorf("Got %v, want %v", got, want)
|
||||
}
|
||||
if got, want := test.shape.String(), test.str; got != want {
|
||||
t.Errorf("Got %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestZeroShape(t *testing.T) {
|
||||
var s Shape
|
||||
if s.NumDimensions() != -1 {
|
||||
t.Error(s.NumDimensions())
|
||||
}
|
||||
if _, err := s.ToSlice(); err == nil {
|
||||
t.Error("ToSlice() on a Shape of unknown number of dimensions should fail")
|
||||
}
|
||||
if s.IsFullySpecified() {
|
||||
t.Error("Shape of unknown number of dimensions should not be fully specified")
|
||||
}
|
||||
if got, want := s.String(), "?"; got != want {
|
||||
t.Errorf("Got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user