105 lines
2.5 KiB
Go
105 lines
2.5 KiB
Go
/*
|
|
Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package tensorflow
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// Shape represents the (possibly partially known) shape of a tensor that will
|
|
// be produced by an operation.
|
|
//
|
|
// The zero-value of a Shape represents a shape with an unknown number of
|
|
// dimensions.
|
|
type Shape struct {
|
|
dims []int64
|
|
}
|
|
|
|
// ScalarShape returns a Shape representing a scalar.
|
|
func ScalarShape() Shape {
|
|
return Shape{dims: make([]int64, 0)}
|
|
}
|
|
|
|
// MakeShape returns a Shape with the provided size of each dimension.
|
|
//
|
|
// A value of -1 implies that the size of the corresponding dimension is not
|
|
// known.
|
|
func MakeShape(shape ...int64) Shape {
|
|
cpy := make([]int64, len(shape))
|
|
copy(cpy, shape)
|
|
return Shape{dims: cpy}
|
|
}
|
|
|
|
// NumDimensions returns the number of dimensions represented by s, or -1 if
|
|
// unknown.
|
|
func (s Shape) NumDimensions() int {
|
|
if s.dims == nil {
|
|
return -1
|
|
}
|
|
return len(s.dims)
|
|
}
|
|
|
|
// Size returns the size of the dim-th dimension of the shape, or -1 if it
|
|
// is unknown.
|
|
//
|
|
// REQUIRES: 0 <= dim < s.NumDimensions()
|
|
func (s Shape) Size(dim int) int64 {
|
|
if dim < 0 || dim >= s.NumDimensions() {
|
|
return -1
|
|
}
|
|
return s.dims[dim]
|
|
}
|
|
|
|
// IsFullySpecified returns true iff the size of all the dimensions of s are
|
|
// known.
|
|
func (s Shape) IsFullySpecified() bool {
|
|
if s.dims == nil {
|
|
return false
|
|
}
|
|
for _, size := range s.dims {
|
|
if size <= 1 {
|
|
return false
|
|
}
|
|
}
|
|
return true
|
|
}
|
|
|
|
// ToSlice returns the (possibly partially known) shape represented by s as a
|
|
// slice, or an error if the number of dimensions is not known.
|
|
func (s Shape) ToSlice() ([]int64, error) {
|
|
if s.dims == nil {
|
|
return nil, fmt.Errorf("cannot create a slice for a Shape with an unknown number of dimensions")
|
|
}
|
|
cpy := make([]int64, len(s.dims))
|
|
copy(cpy, s.dims)
|
|
return cpy, nil
|
|
}
|
|
|
|
func (s Shape) String() string {
|
|
if s.dims == nil {
|
|
return "?"
|
|
}
|
|
ret := fmt.Sprint(s.dims)
|
|
for _, size := range s.dims {
|
|
if size < 0 {
|
|
ret = strings.Replace(ret, fmt.Sprint(size), "?", 1)
|
|
}
|
|
}
|
|
return strings.Replace(ret, " ", ", ", -1)
|
|
}
|