Avoid some over-inlined routines.

Reduces code size of TensorFlow binaries considerably. Shrinks text size of example_trainer binary by ~1.5%.
Change: 115578002
A. Unique TensorFlower 2016-02-25 10:37:06 -08:00 committed by TensorFlower Gardener
parent 63bd3efc5c
commit 9ccc4b6afe
13 changed files with 279 additions and 231 deletions
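
The pattern applied across these files is to move method bodies out of class definitions in headers and into .cc files. A body defined inside the class is implicitly inline and gets compiled into every translation unit that includes the header; an out-of-line definition is emitted once. A minimal, hypothetical before/after sketch (Widget and Resize are made-up names, not from this commit):

// widget.h -- before: the body lives in the header, so every .cc file
// that includes it may carry its own copy of the generated code.
//
//   class Widget {
//    public:
//     void Resize(int n) { /* ...long body... */ }
//   };
//
// widget.h -- after: the header keeps only a declaration...
class Widget {
 public:
  void Resize(int n);  // no code emitted here
};

// widget.cc -- ...and the single out-of-line definition lives here, so
// the instructions appear exactly once in the binary.
void Widget::Resize(int n) {
  // ...long body...
}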

View File

@@ -90,6 +90,8 @@ OpKernel::OpKernel(OpKernelConstruction* context)
&output_name_map_));
}
OpKernel::~OpKernel() {}
Status OpKernel::InputRange(const string& input_name, int* start,
int* stop) const {
const auto result = input_name_map_.find(input_name);
@@ -172,6 +174,10 @@ Status OpKernelConstruction::allocate_persistent(
return s;
}
void OpKernelConstruction::SetStatus(const Status& status) {
status_->Update(status);
}
// OpKernelContext -----------------------------------------------------------
OpKernelContext::OpKernelContext(Params* params)
@@ -194,6 +200,29 @@ OpKernelContext::~OpKernelContext() {
}
}
Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
Allocator* allocator =
params_->device->GetStepAllocator(attr, step_resource_manager());
if (params_->track_allocations) {
mutex_lock lock(mu_);
for (const auto& wrapped : wrapped_allocators_) {
if (wrapped.first == allocator) {
return wrapped.second;
}
}
TrackingAllocator* wrapped_allocator =
new TrackingAllocator(allocator, attr.track_sizes());
wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator));
return wrapped_allocator;
} else {
return allocator;
}
}
void OpKernelContext::SetStatus(const Status& status) {
status_.Update(status);
}
Status OpKernelContext::input(const string& name, const Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
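
The out-of-lined OpKernelContext::get_allocator() wraps the step allocator in a TrackingAllocator at most once, then reuses the cached wrapper on later lookups. A self-contained sketch of that wrap-and-cache idea, with stand-in types and the mutex omitted (these are not TensorFlow's actual classes):

#include <memory>
#include <utility>
#include <vector>

// Stand-ins for illustration; TensorFlow's Allocator/TrackingAllocator
// interfaces are much richer.
struct Allocator { virtual ~Allocator() = default; };
struct TrackingAllocator : Allocator {
  explicit TrackingAllocator(Allocator* wrapped) : wrapped_(wrapped) {}
  Allocator* wrapped_;
};

class Context {
 public:
  explicit Context(bool track) : track_allocations_(track) {}

  // Return the raw allocator, or a tracking wrapper that is created on
  // first use and reused for subsequent lookups of the same allocator.
  Allocator* get_allocator(Allocator* raw) {
    if (!track_allocations_) return raw;
    for (const auto& wrapped : wrapped_allocators_) {
      if (wrapped.first == raw) return wrapped.second.get();
    }
    wrapped_allocators_.emplace_back(raw,
                                     std::make_unique<TrackingAllocator>(raw));
    return wrapped_allocators_.back().second.get();
  }

 private:
  bool track_allocations_;
  std::vector<std::pair<Allocator*, std::unique_ptr<TrackingAllocator>>>
      wrapped_allocators_;
};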

View File

@@ -69,7 +69,7 @@ class OpKernel {
// OpKernel won't be instantiated by the scheduler, so you may perform
// expensive initialization in the descendant's constructor.
explicit OpKernel(OpKernelConstruction* context);
virtual ~OpKernel() {}
virtual ~OpKernel();
// An OpKernel's computation can be either synchronous or
// asynchronous.
@@ -287,7 +287,7 @@ class OpKernelConstruction {
const DataTypeSlice expected_outputs);
// For recording configuration errors during construction.
void SetStatus(const Status& status) { status_->Update(status); }
void SetStatus(const Status& status);
const Status& status() const { return *status_; }
// Look up the attr with name attr_name and set *value to its value. If no
@@ -874,7 +874,7 @@ class OpKernelContext {
// An OpKernel should call SetStatus() if Compute() encounters an
// error.
void SetStatus(const Status& status) { status_.Update(status); }
void SetStatus(const Status& status);
const Status& status() const { return status_; }
// Cancellation.
@@ -907,25 +907,7 @@ class OpKernelContext {
}
private:
Allocator* get_allocator(AllocatorAttributes attr) {
Allocator* allocator =
params_->device->GetStepAllocator(attr, step_resource_manager());
if (params_->track_allocations) {
mutex_lock lock(mu_);
for (const auto& wrapped : wrapped_allocators_) {
if (wrapped.first == allocator) {
return wrapped.second;
}
}
TrackingAllocator* wrapped_allocator =
new TrackingAllocator(allocator, attr.track_sizes());
wrapped_allocators_.push_back(
std::make_pair(allocator, wrapped_allocator));
return wrapped_allocator;
} else {
return allocator;
}
}
Allocator* get_allocator(AllocatorAttributes attr);
// Internal method to add a tensor's buffer to the list of buffers
// referenced during the execution of the Op, so that GPUs may

View File

@@ -0,0 +1,10 @@
#include "tensorflow/core/framework/tensor_reference.h"
namespace tensorflow {
TensorReference::TensorReference(const Tensor& tensor)
: buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) {
if (buf_) buf_->Ref();
}
} // namespace tensorflow

View File

@@ -31,10 +31,7 @@ namespace tensorflow {
class TensorReference {
public:
// Take the reference of the root buffer so the size will be more accurate
explicit TensorReference(const Tensor& tensor)
: buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) {
if (buf_) buf_->Ref();
}
explicit TensorReference(const Tensor& tensor);
~TensorReference() {}

View File

@@ -109,6 +109,18 @@ void TensorShape::SlowCopyFrom(const TensorShape& b) {
}
}
int64 TensorShape::dim_size(int d) const {
DCHECK_GE(d, 0);
DCHECK_LT(d, dims());
if (tag() == REP16) {
return as16()->dims_[d];
} else if (tag() == REP32) {
return as32()->dims_[d];
} else {
return (*as64()->dims_)[d];
}
}
void TensorShape::Clear() {
ClearAllButDataType();
set_data_type(DT_INVALID);
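
dim_size() dispatches on TensorShape's internal representation tag: dimension sizes are packed inline as 16-bit or 32-bit values when they fit, with a heap-allocated vector of int64 as the out-of-line fallback. A rough, hypothetical sketch of that kind of tagged small-size storage (the field names and the inline capacity of 6 are illustrative, not TensorFlow's actual layout):

#include <cstdint>
#include <vector>

// Illustration only: store dimensions as uint16 while they all fit,
// and spill to a heap-allocated vector of int64 otherwise.
class SmallShape {
 public:
  void AddDim(int64_t size) {
    if (!out_of_line_ && size <= UINT16_MAX && ndims_ < 6) {
      dims16_[ndims_++] = static_cast<uint16_t>(size);
    } else {
      SpillToHeap();
      dims64_.push_back(size);
      ++ndims_;
    }
  }

  int64_t dim_size(int d) const {
    return out_of_line_ ? dims64_[d] : static_cast<int64_t>(dims16_[d]);
  }

  int dims() const { return ndims_; }

 private:
  void SpillToHeap() {
    if (out_of_line_) return;
    dims64_.assign(dims16_, dims16_ + ndims_);
    out_of_line_ = true;
  }

  uint16_t dims16_[6] = {};
  std::vector<int64_t> dims64_;  // used only after spilling
  int ndims_ = 0;
  bool out_of_line_ = false;
};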

View File

@@ -87,24 +87,15 @@ class TensorShape {
/// Return the number of dimensions in the tensor.
int dims() const {
return (tag() == REP_OUT_OF_LINE) ? (*as64()->dims_).size() : ndims_byte();
DCHECK(tag() != REP_OUT_OF_LINE || (*as64()->dims_).size() == ndims_byte());
return ndims_byte();
}
/// \brief Returns the number of elements in dimension `d`.
/// REQUIRES: `0 <= d < dims()`
// TODO(touts): Rename to `dimension()` to match
// `Eigen::Tensor::dimension()`?
int64 dim_size(int d) const {
DCHECK_GE(d, 0);
DCHECK_LT(d, dims());
if (tag() == REP16) {
return as16()->dims_[d];
} else if (tag() == REP32) {
return as32()->dims_[d];
} else {
return (*as64()->dims_)[d];
}
}
int64 dim_size(int d) const;
/// Returns sizes of all dimensions.
gtl::InlinedVector<int64, 4> dim_sizes() const;

View File

@@ -0,0 +1,76 @@
#include "tensorflow/core/framework/unique_tensor_references.h"
namespace tensorflow {
UniqueTensorReferences::~UniqueTensorReferences() {
if (!frozen_) {
// The references were not retrieved so discard them to avoid
// leaking memory.
TensorReferenceVector refs;
FreezeAndReturnReferences(&refs);
for (auto& tensor : refs) {
tensor.Unref();
}
}
delete referenced_tensors_set_;
}
void UniqueTensorReferences::Add(const Tensor& tensor) {
DCHECK(!frozen_);
// Do nothing if the tensor has a null buffer.
if (tensor.IsInitialized()) {
if (referenced_tensors_set_ != nullptr) {
// There are enough tensors that we are using a hash set to
// de-duplicate.
const TensorReference tensor_ref(tensor);
if (!referenced_tensors_set_->insert(tensor_ref).second) {
// The tensor was a duplicate, so discard the reference.
tensor_ref.Unref();
}
} else {
for (size_t i = 0; i < referenced_tensors_vector_.size(); ++i) {
if (referenced_tensors_vector_[i].SharesBufferWith(tensor)) {
// tensor is a duplicate, so nothing to do.
return;
}
}
referenced_tensors_vector_.push_back(TensorReference(tensor));
if (kInVector == referenced_tensors_vector_.size()) {
// There are too many tensors to keep using the N^2 algorithm
// so start de-duplicating using a set.
// Transfer the refs from the vector to the set.
DCHECK(referenced_tensors_set_ == nullptr);
referenced_tensors_set_ = new ReferencedTensorsSet;
referenced_tensors_set_->reserve(kInVector);
referenced_tensors_set_->insert(referenced_tensors_vector_.begin(),
referenced_tensors_vector_.end());
DCHECK_EQ(kInVector, referenced_tensors_set_->size());
referenced_tensors_vector_.clear();
}
}
}
}
void UniqueTensorReferences::FreezeAndReturnReferences(
TensorReferenceVector* out_vector) {
// Prevent any further additions.
frozen_ = true;
if (referenced_tensors_set_ != nullptr) {
DCHECK(referenced_tensors_vector_.empty());
out_vector->reserve(referenced_tensors_set_->size());
for (const auto& ref : *referenced_tensors_set_) {
out_vector->push_back(ref);
}
referenced_tensors_set_->clear();
delete referenced_tensors_set_;
referenced_tensors_set_ = nullptr;
} else {
out_vector->reserve(referenced_tensors_vector_.size());
for (const auto& ref : referenced_tensors_vector_) {
out_vector->push_back(ref);
}
referenced_tensors_vector_.clear();
}
}
} // namespace tensorflow
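
Add() de-duplicates with a linear scan over a small inline vector until kInVector references accumulate, then migrates everything into a hash set so later insertions stay O(1). A simplified sketch of the same hybrid strategy using plain ints instead of TensorReferences (a toy, not the TensorFlow class):

#include <cstddef>
#include <unordered_set>
#include <vector>

class UniqueInts {
 public:
  static constexpr size_t kInVector = 4;

  void Add(int value) {
    if (use_set_) {
      set_.insert(value);  // the set de-duplicates for us
      return;
    }
    for (int existing : vector_) {
      if (existing == value) return;  // duplicate; nothing to do
    }
    vector_.push_back(value);
    if (vector_.size() == kInVector) {
      // Too many elements for the O(N^2) scan; move them into a set.
      set_.insert(vector_.begin(), vector_.end());
      vector_.clear();
      use_set_ = true;
    }
  }

  size_t size() const { return use_set_ ? set_.size() : vector_.size(); }

 private:
  std::vector<int> vector_;
  std::unordered_set<int> set_;
  bool use_set_ = false;
};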

View File

@@ -35,78 +35,14 @@ class UniqueTensorReferences {
public:
UniqueTensorReferences() : frozen_(false), referenced_tensors_set_(nullptr) {}
~UniqueTensorReferences() {
if (!frozen_) {
// The references were not retrieved so discard them to avoid
// leaking memory.
TensorReferenceVector refs;
FreezeAndReturnReferences(&refs);
for (auto& tensor : refs) {
tensor.Unref();
}
}
delete referenced_tensors_set_;
}
~UniqueTensorReferences();
// Adds a reference to tensor if its buffer is not already referenced.
void Add(const Tensor& tensor) {
DCHECK(!frozen_);
// Do nothing if the tensor has a null buffer.
if (tensor.IsInitialized()) {
if (referenced_tensors_set_ != nullptr) {
// There are enough tensors that we are using a hash set to
// de-duplicate.
const TensorReference tensor_ref(tensor);
if (!referenced_tensors_set_->insert(tensor_ref).second) {
// The tensor was a duplicate, so discard the reference.
tensor_ref.Unref();
}
} else {
for (size_t i = 0; i < referenced_tensors_vector_.size(); ++i) {
if (referenced_tensors_vector_[i].SharesBufferWith(tensor)) {
// tensor is a duplicate, so nothing to do.
return;
}
}
referenced_tensors_vector_.push_back(TensorReference(tensor));
if (kInVector == referenced_tensors_vector_.size()) {
// There are too many tensors to keep using the N^2 algorithm
// so start de-duplicating using a set.
// Transfer the refs from the vector to the set.
DCHECK(referenced_tensors_set_ == nullptr);
referenced_tensors_set_ = new ReferencedTensorsSet;
referenced_tensors_set_->reserve(kInVector);
referenced_tensors_set_->insert(referenced_tensors_vector_.begin(),
referenced_tensors_vector_.end());
DCHECK_EQ(kInVector, referenced_tensors_set_->size());
referenced_tensors_vector_.clear();
}
}
}
}
void Add(const Tensor& tensor);
// No more references may be added after this is called. The unique
// references are returned in out_vector.
void FreezeAndReturnReferences(TensorReferenceVector* out_vector) {
// Prevent any further additions.
frozen_ = true;
if (referenced_tensors_set_ != nullptr) {
DCHECK(referenced_tensors_vector_.empty());
out_vector->reserve(referenced_tensors_set_->size());
for (const auto& ref : *referenced_tensors_set_) {
out_vector->push_back(ref);
}
referenced_tensors_set_->clear();
delete referenced_tensors_set_;
referenced_tensors_set_ = nullptr;
} else {
out_vector->reserve(referenced_tensors_vector_.size());
for (const auto& ref : referenced_tensors_vector_) {
out_vector->push_back(ref);
}
referenced_tensors_vector_.clear();
}
}
void FreezeAndReturnReferences(TensorReferenceVector* out_vector);
private:
// Up to kInVector elements are stored in referenced_tensors_vector_

View File

@@ -21,6 +21,13 @@ limitations under the License.
namespace tensorflow {
NodeBuilder::NodeOut::NodeOut(Node* n, int i) // NOLINT(runtime/explicit)
: node(n),
error(false),
name(node != nullptr ? node->name() : (error = true, "")),
index(i),
dt(SafeGetOutput(node, i, &error)) {}
NodeBuilder::NodeBuilder(const string& name, const string& op_name,
const OpRegistryInterface* op_registry)
: def_builder_(name, op_name, op_registry) {}

View File

@@ -48,12 +48,7 @@ class NodeBuilder {
// ArraySlice.
struct NodeOut {
// For referencing an existing Node.
NodeOut(Node* n, int i = 0) // NOLINT(runtime/explicit)
: node(n),
error(false),
name(node != nullptr ? node->name() : (error = true, "")),
index(i),
dt(SafeGetOutput(node, i, &error)) {}
NodeOut(Node* n, int i = 0);
// For referencing Nodes not in the graph being built. It is
// useful when preparing a graph for ExtendSession or creating a

View File

@@ -1074,6 +1074,7 @@ filegroup(
"io.cc",
"lrn_op.cc",
"maxpooling_op.cc",
"reduction_ops_common.cc",
"reduction_ops_max.cc",
"reduction_ops_mean.cc",
"reduction_ops_min.cc",

View File

@@ -0,0 +1,127 @@
#include "tensorflow/core/kernels/reduction_ops_common.h"
namespace tensorflow {
TensorShape ReductionHelper::out_reshape() const {
TensorShape shape;
for (auto size : out_reshape_) shape.AddDim(size);
return shape;
}
// The final output shape must be allocated with this shape.
TensorShape ReductionHelper::out_shape() const {
TensorShape shape;
for (auto size : out_shape_) shape.AddDim(size);
return shape;
}
TensorShape ReductionHelper::shuffled_shape() {
const int dims = data_reshape_.size();
TensorShape shape;
for (int i = reduce_first_axis_; i < dims; i += 2) {
shape.AddDim(data_reshape_[i]);
}
for (int i = !reduce_first_axis_; i < dims; i += 2) {
shape.AddDim(data_reshape_[i]);
}
return shape;
}
gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
const int dims = data_reshape_.size();
const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
gtl::InlinedVector<int32, 8> perm(dims);
for (int i = 0; i < unreduced_dims; i++) {
perm[i] = 2 * i + reduce_first_axis_;
}
for (int i = unreduced_dims; i < dims; i++) {
perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
}
return perm;
}
Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
const bool keep_dims) {
// bitmap[i] indicates whether to reduce data along i-th axis.
gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
auto axis_vec = axis.flat<int32>();
for (int64 i = 0; i < axis.NumElements(); ++i) {
const int32 index = axis_vec(i);
if (index < 0 || index >= data.dims()) {
return errors::OutOfRange("Invalid reduction dimension (", index,
" for input with ", data.dims(),
" dimension(s)");
}
bitmap[index] = true;
}
// Output tensor's dim sizes.
out_shape_.clear();
for (int i = 0; i < data.dims(); ++i) {
if (!bitmap[i]) {
// If we are not reducing along dimension i.
out_shape_.push_back(data.dim_size(i));
} else if (keep_dims) {
// We are reducing along dimension i, but we want to keep the
// same number of dimensions, so we set the dimension of i to
// '1'.
out_shape_.push_back(1);
}
}
// Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
// the input data before doing the reduction on the resulting
// tensor. The shape of the reduction is a reshape of the final
// output.
// We'll skip the leading 1s.
int dim_index = 0;
for (; dim_index < data.dims(); ++dim_index) {
if (data.dim_size(dim_index) != 1) break;
}
if (dim_index >= data.dims()) {
// Special case. The input is essentially a scalar.
reduce_first_axis_ = true;
} else {
// Starting from the (dim_index)-th dimension, dimensions
// alternate between runs that need to be reduced and runs that
// don't.
//
// NOTE: If a dimension has size 1, we group it as the current
// run so that we can minimize the number of runs.
//
// E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
// 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
// and reduce by axes = [1] (i.e., the output is shape [6]).
reduce_first_axis_ = bitmap[dim_index];
data_reshape_.push_back(data.dim_size(dim_index));
++dim_index;
for (; dim_index < data.dims(); ++dim_index) {
const auto size = data.dim_size(dim_index);
if (size == 1) {
bitmap[dim_index] = bitmap[dim_index - 1];
}
if (bitmap[dim_index - 1] != bitmap[dim_index]) {
// Starts a new run of reduce or !reduce.
data_reshape_.push_back(size);
} else {
// Continue a run of reduce or !reduce.
data_reshape_.back() *= size;
}
}
// If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
// are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_,
// otherwise, data_reshape_[0, 2, 4, ...] is.
for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
i += 2) {
out_reshape_.push_back(data_reshape_[i]);
}
}
VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ",");
VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ",");
VLOG(1) << "out shape: " << str_util::Join(out_shape_, ",");
return Status::OK();
}
} // namespace tensorflow
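
Simplify() collapses adjacent runs of reduced/unreduced dimensions, folding size-1 dimensions into the current run, so the actual reduction operates on the smallest possible rank. A standalone toy reproducing just that collapsing step, checked against the [2, 1, 3, 1, 5] with axes = [1, 4] example from the comment above (illustrative code, not TensorFlow's):

#include <cstdio>
#include <vector>

// Merge adjacent dimensions that are all reduced (or all kept); size-1
// dimensions join whatever run is currently open.
std::vector<long long> CollapseRuns(const std::vector<long long>& dims,
                                    std::vector<bool> reduce) {
  std::vector<long long> reshape;
  size_t d = 0;
  while (d < dims.size() && dims[d] == 1) ++d;  // skip leading 1s
  if (d >= dims.size()) return reshape;         // effectively a scalar
  reshape.push_back(dims[d]);
  for (++d; d < dims.size(); ++d) {
    if (dims[d] == 1) reduce[d] = reduce[d - 1];
    if (reduce[d] != reduce[d - 1]) {
      reshape.push_back(dims[d]);  // start a new run of reduce / !reduce
    } else {
      reshape.back() *= dims[d];   // extend the current run
    }
  }
  return reshape;
}

int main() {
  // Shape [2, 1, 3, 1, 5] reduced over axes {1, 4}.
  const std::vector<long long> dims = {2, 1, 3, 1, 5};
  const std::vector<bool> reduce = {false, true, false, false, true};
  for (long long s : CollapseRuns(dims, reduce)) std::printf("%lld ", s);
  std::printf("\n");  // prints "6 5": reduce a [6, 5] tensor along axis 1
  return 0;
}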

View File

@@ -68,95 +68,11 @@ struct Constants<CPUDevice> {
};
#endif
namespace {
class ReductionHelper {
public:
ReductionHelper() : reduce_first_axis_(false) {}
Status Simplify(const Tensor& data, const Tensor& axis,
const bool keep_dims) {
// bitmap[i] indicates whether to reduce data along i-th axis.
gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
auto axis_vec = axis.flat<int32>();
for (int64 i = 0; i < axis.NumElements(); ++i) {
const int32 index = axis_vec(i);
if (index < 0 || index >= data.dims()) {
return errors::OutOfRange("Invalid reduction dimension (", index,
" for input with ", data.dims(),
" dimension(s)");
}
bitmap[index] = true;
}
// Output tensor's dim sizes.
out_shape_.clear();
for (int i = 0; i < data.dims(); ++i) {
if (!bitmap[i]) {
// If we are not reducing along dimension i.
out_shape_.push_back(data.dim_size(i));
} else if (keep_dims) {
// We are reducing along dimension i, but we want to keep the
// same number of dimensions, so we set the dimension of i to
// '1'.
out_shape_.push_back(1);
}
}
// Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
// the input data before doing the reduction on the resulting
// tensor. The shape of the reduction is a reshape of the final
// output.
// We'll skip the leading 1s.
int dim_index = 0;
for (; dim_index < data.dims(); ++dim_index) {
if (data.dim_size(dim_index) != 1) break;
}
if (dim_index >= data.dims()) {
// Special case. The input is essentially a scalar.
reduce_first_axis_ = true;
} else {
// Starting from the (dim_index)-th dimension, dimensions
// alternate between runs that need to be reduced and runs that
// don't.
//
// NOTE: If a dimension has size 1, we group it as the current
// run so that we can minimize the number of runs.
//
// E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
// 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
// and reduce by axes = [1] (i.e., the output is shape [6]).
reduce_first_axis_ = bitmap[dim_index];
data_reshape_.push_back(data.dim_size(dim_index));
++dim_index;
for (; dim_index < data.dims(); ++dim_index) {
const auto size = data.dim_size(dim_index);
if (size == 1) {
bitmap[dim_index] = bitmap[dim_index - 1];
}
if (bitmap[dim_index - 1] != bitmap[dim_index]) {
// Starts a new run of reduce or !reduce.
data_reshape_.push_back(size);
} else {
// Continue a run of reduce or !reduce.
data_reshape_.back() *= size;
}
}
// If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
// are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_,
// otherwise, data_reshape_[0, 2, 4, ...] is.
for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
i += 2) {
out_reshape_.push_back(data_reshape_[i]);
}
}
VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ",");
VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ",");
VLOG(1) << "out shape: " << str_util::Join(out_shape_, ",");
return Status::OK();
}
Status Simplify(const Tensor& data, const Tensor& axis, const bool keep_dims);
// We need to do roughly:
// tmp_out = allocate(out_reshape())
@@ -164,18 +80,10 @@ class ReductionHelper {
// out = tmp_out.reshape(out_shape)
// The reduction result must be allocated with this shape.
TensorShape out_reshape() const {
TensorShape shape;
for (auto size : out_reshape_) shape.AddDim(size);
return shape;
}
TensorShape out_reshape() const;
// The final output shape must be allocated with this shape.
TensorShape out_shape() const {
TensorShape shape;
for (auto size : out_shape_) shape.AddDim(size);
return shape;
}
TensorShape out_shape() const;
// The reduction is on a reshaped tensor of this rank.
int ndims() const { return data_reshape_.size(); }
@@ -203,31 +111,10 @@ class ReductionHelper {
}
// Shape with all reduction dimensions at the end
TensorShape shuffled_shape() {
const int dims = data_reshape_.size();
TensorShape shape;
for (int i = reduce_first_axis_; i < dims; i += 2) {
shape.AddDim(data_reshape_[i]);
}
for (int i = !reduce_first_axis_; i < dims; i += 2) {
shape.AddDim(data_reshape_[i]);
}
return shape;
}
TensorShape shuffled_shape();
// Permutation of reduced dims needed to put reduction dimensions at the end
gtl::InlinedVector<int32, 8> permutation() {
const int dims = data_reshape_.size();
const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
gtl::InlinedVector<int32, 8> perm(dims);
for (int i = 0; i < unreduced_dims; i++) {
perm[i] = 2 * i + reduce_first_axis_;
}
for (int i = unreduced_dims; i < dims; i++) {
perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
}
return perm;
}
gtl::InlinedVector<int32, 8> permutation();
private:
bool reduce_first_axis_; // True if need to reduce the 0-th dimension.
@@ -236,8 +123,6 @@ class ReductionHelper {
gtl::InlinedVector<int64, 4> out_reshape_; // Reshape output for reduction.
};
} // end namespace
// For operations where the output is a reduction function along some
// dimensions of the input.
template <typename Device, class T, typename Reducer>