187 lines
7.3 KiB
C++
187 lines
7.3 KiB
C++
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
|
|
#define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
|
|
|
|
#include <map>
#include <memory>
#include <set>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
|
|
|
|
namespace tensorflow {
|
|
|
|
// Represents a resource, such as a Variable or TensorArray.
|
|
class XlaResource {
|
|
public:
|
|
enum Kind {
|
|
kInvalid,
|
|
kVariable,
|
|
kTensorArray,
|
|
kStack,
|
|
};
|
|
static absl::string_view KindToString(Kind kind);
|
|
|
|
// Creates a new Stack resource.
|
|
static std::unique_ptr<XlaResource> CreateStack(string name, DataType type,
|
|
int64 max_size);
|
|
|
|
// Creates a new TensorArray resource.
|
|
static std::unique_ptr<XlaResource> CreateTensorArray(
|
|
string name, DataType type, TensorShape shape, xla::XlaOp initial_value,
|
|
int64 max_array_size);
|
|
|
|
XlaResource(Kind kind, int arg_num, string name, DataType type,
|
|
TensorShape shape, const xla::XlaOp& initial_value,
|
|
int64 max_array_size,
|
|
const std::set<string>& tensor_array_gradients,
|
|
bool tensor_array_multiple_writes_aggregate);
|
|
|
|
XlaResource(const XlaResource&) = delete;
|
|
XlaResource(XlaResource&&) = delete;
|
|
XlaResource& operator=(const XlaResource&) = delete;
|
|
XlaResource& operator=(XlaResource&&) = delete;
|
|
|
|
Kind kind() const { return kind_; }
|
|
|
|
// If this resource is visible externally to the computation, what was its
|
|
// argument number?
|
|
// < 0 means "not visible externally".
|
|
int arg_num() const { return arg_num_; }
|
|
|
|
// A descriptive name for the resource, used in error messages.
|
|
const string& name() const { return name_; }
|
|
|
|
// Current type and value of the resource. Uninitialized resources are
|
|
// represented by a default (zero) handle and type DT_INVALID.
|
|
// While the type of a resource is notionally fixed during execution, when
|
|
// a resource is first initialized we do not yet know its type, so we keep
|
|
// track of its type dynamically.
|
|
DataType type() const { return type_; }
|
|
|
|
// Shape of the resource. For an uninitialized resource, this is ignored.
|
|
// For a Variable, this is the shape of the value. For a TensorArray or Stack
|
|
// this is the shape of each entry in the TensorArray/Stack.
|
|
const TensorShape& shape() const { return shape_; }
|
|
|
|
const xla::XlaOp& value() const { return value_; }
|
|
|
|
// Value of the resource at computation entry. Used to detect which
|
|
// variables have new values that need to be written back.
|
|
const xla::XlaOp& initial_value() const { return initial_value_; }
|
|
|
|
// An xla shape that indicates how this resource variable is represented on
|
|
// device.
|
|
const absl::optional<xla::Shape>& representation_shape() const {
|
|
return representation_shape_;
|
|
}
|
|
|
|
// A variable is initialized if it has a value.
|
|
bool initialized() const { return value_.valid(); }
|
|
|
|
// Sets the type and shape of the resource. The type and shape of a resource
|
|
// must not change once the variable has been initialized.
|
|
Status SetTypeAndShape(DataType type, const TensorShape& shape);
|
|
|
|
// Sets the current value of the resource. Returns an error if the type is not
|
|
// set to a valid value.
|
|
Status SetValue(const xla::XlaOp& value);
|
|
|
|
// Sets the current value of the resource to an all-zero value.
|
|
Status SetZeroValue(xla::XlaBuilder* builder);
|
|
|
|
// Sets the representational shape of the resource on device.
|
|
void SetRepresentationShape(const xla::Shape& shape) {
|
|
representation_shape_ = absl::make_optional(shape);
|
|
}
|
|
|
|
// Looks up the gradient for `source`, or creates it if it does not already
|
|
// exist. The call target must be an initialized TensorArray resource. A
|
|
// TensorArray can have multiple named gradients; see the operator
|
|
// documentation for TensorArrayGradV3 for details.
|
|
Status GetOrCreateTensorArrayGradient(const string& source,
|
|
xla::XlaBuilder* builder,
|
|
XlaResource** gradient_out);
|
|
|
|
// Packs a resource into a single XLA value `pack`, suitable for use as
|
|
// an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without
|
|
// gradients, sets `*pack` to `value`.
|
|
// For TensorArrays with gradients, packs the value and its gradient values in
|
|
// a tuple; the gradients values are packed in order by source name.
|
|
Status Pack(xla::XlaOp* pack, xla::XlaBuilder* builder) const;
|
|
|
|
// Updates the resource with values from `pack`. If `gradient_sources` is
|
|
// non-empty, treats `pack` as a tuple that represents a TensorArray and
|
|
// its gradients, and unpacks and updates the gradient resources.
|
|
// If `reset_initial_values` is true, sets the initial_values as well as the
|
|
// values.
|
|
// Opposite of Pack().
|
|
Status SetFromPack(const std::set<string>& gradient_sources,
|
|
const xla::XlaOp& pack, xla::XlaBuilder* builder);
|
|
|
|
// TensorArray and Stack specific fields
|
|
// TODO(phawkins): refactor this code to use subclasses, rather than putting
|
|
// kind-specific fields in XlaResource.
|
|
|
|
// 'max_array_size' stores the expected size of the TensorArray or Stack.
|
|
// We need to store this since sometimes TensorArrays must be initialized
|
|
// lazily since we do not know the element shape at construction time.
|
|
// Used by both TensorArrays and Stacks.
|
|
int64 max_array_size() const { return max_array_size_; }
|
|
void set_max_array_size(int64 size) { max_array_size_ = size; }
|
|
|
|
bool tensor_array_multiple_writes_aggregate() const {
|
|
return tensor_array_multiple_writes_aggregate_;
|
|
}
|
|
|
|
// 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
|
|
// to an XlaResource containing the gradient TensorArrays. We store a pointer
|
|
// here since there should only be one gradient TensorArray per 'source'
|
|
// string, irrespective of the number of calls to TensorArrayGrad. The map
|
|
// is ordered since values are packed into tuples by Pack() sorted by name
|
|
// order.
|
|
const std::map<string, std::unique_ptr<XlaResource>>& tensor_array_gradients()
|
|
const {
|
|
return tensor_array_gradients_;
|
|
}
|
|
|
|
private:
|
|
const Kind kind_;
|
|
const int arg_num_;
|
|
const string name_;
|
|
|
|
DataType type_;
|
|
TensorShape shape_;
|
|
xla::XlaOp value_;
|
|
xla::XlaOp initial_value_;
|
|
|
|
// An xla shape that indicates how this resource variable is represented on
|
|
// device.
|
|
absl::optional<xla::Shape> representation_shape_;
|
|
|
|
int64 max_array_size_ = -1;
|
|
bool tensor_array_multiple_writes_aggregate_ = false;
|
|
|
|
std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
|
|
};
|
|
|
|
} // namespace tensorflow
|
|
|
|
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
|