Avoid some over-inlined routines. Reduces code size of TensorFlow binaries
considerably. Shrinks text size of example_trainer binary by ~1.5%. Change: 115578002
Parent: 63bd3efc5c
Commit: 9ccc4b6afe
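Every file touched below follows the same pattern: a member function that was defined inline in a header is reduced to a declaration, and its body moves to the matching .cc file, so the machine code is emitted once instead of in every translation unit that includes the header. A minimal sketch of the pattern, using a hypothetical Foo class rather than anything from this diff:

// foo.h -- before: the body is visible to, and duplicated in, every includer.
//   class Foo {
//    public:
//     ~Foo() { delete state_; }
//   };

// foo.h -- after: declaration only.
class Foo {
 public:
  ~Foo();

 private:
  int* state_ = nullptr;
};

// foo.cc -- after: a single out-of-line definition.
Foo::~Foo() { delete state_; }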
tensorflow/core/framework/op_kernel.cc

@@ -90,6 +90,8 @@ OpKernel::OpKernel(OpKernelConstruction* context)
                                  &output_name_map_));
 }
 
+OpKernel::~OpKernel() {}
+
 Status OpKernel::InputRange(const string& input_name, int* start,
                             int* stop) const {
   const auto result = input_name_map_.find(input_name);

@@ -172,6 +174,10 @@ Status OpKernelConstruction::allocate_persistent(
   return s;
 }
 
+void OpKernelConstruction::SetStatus(const Status& status) {
+  status_->Update(status);
+}
+
 // OpKernelContext -----------------------------------------------------------
 
 OpKernelContext::OpKernelContext(Params* params)

@@ -194,6 +200,29 @@ OpKernelContext::~OpKernelContext() {
   }
 }
 
+Allocator* OpKernelContext::get_allocator(AllocatorAttributes attr) {
+  Allocator* allocator =
+      params_->device->GetStepAllocator(attr, step_resource_manager());
+  if (params_->track_allocations) {
+    mutex_lock lock(mu_);
+    for (const auto& wrapped : wrapped_allocators_) {
+      if (wrapped.first == allocator) {
+        return wrapped.second;
+      }
+    }
+    TrackingAllocator* wrapped_allocator =
+        new TrackingAllocator(allocator, attr.track_sizes());
+    wrapped_allocators_.push_back(std::make_pair(allocator, wrapped_allocator));
+    return wrapped_allocator;
+  } else {
+    return allocator;
+  }
+}
+
+void OpKernelContext::SetStatus(const Status& status) {
+  status_.Update(status);
+}
+
 Status OpKernelContext::input(const string& name, const Tensor** tensor) {
   int start, stop;
   TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));

tensorflow/core/framework/op_kernel.h

@@ -69,7 +69,7 @@ class OpKernel {
   // OpKernel won't be instantiated by the scheduler, so you may perform
   // expensive initialization in the descendant's constructor.
   explicit OpKernel(OpKernelConstruction* context);
-  virtual ~OpKernel() {}
+  virtual ~OpKernel();
 
   // An OpKernel's computation can be either synchronous or
   // asynchronous.

@@ -287,7 +287,7 @@ class OpKernelConstruction {
                        const DataTypeSlice expected_outputs);
 
   // For recording configuration errors during construction.
-  void SetStatus(const Status& status) { status_->Update(status); }
+  void SetStatus(const Status& status);
   const Status& status() const { return *status_; }
 
   // Look up the attr with name attr_name and set *value to its value. If no

@@ -874,7 +874,7 @@ class OpKernelContext {
 
   // An OpKernel should call SetStatus() if Compute() encounters an
   // error.
-  void SetStatus(const Status& status) { status_.Update(status); }
+  void SetStatus(const Status& status);
   const Status& status() const { return status_; }
 
   // Cancellation.

@@ -907,25 +907,7 @@ class OpKernelContext {
   }
 
  private:
-  Allocator* get_allocator(AllocatorAttributes attr) {
-    Allocator* allocator =
-        params_->device->GetStepAllocator(attr, step_resource_manager());
-    if (params_->track_allocations) {
-      mutex_lock lock(mu_);
-      for (const auto& wrapped : wrapped_allocators_) {
-        if (wrapped.first == allocator) {
-          return wrapped.second;
-        }
-      }
-      TrackingAllocator* wrapped_allocator =
-          new TrackingAllocator(allocator, attr.track_sizes());
-      wrapped_allocators_.push_back(
-          std::make_pair(allocator, wrapped_allocator));
-      return wrapped_allocator;
-    } else {
-      return allocator;
-    }
-  }
+  Allocator* get_allocator(AllocatorAttributes attr);
 
   // Internal method to add a tensor's buffer to the list of buffers
   // referenced during the execution of the Op, so that GPUs may

tensorflow/core/framework/tensor_reference.cc (new file, 10 lines)

@@ -0,0 +1,10 @@
+#include "tensorflow/core/framework/tensor_reference.h"
+
+namespace tensorflow {
+
+TensorReference::TensorReference(const Tensor& tensor)
+    : buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) {
+  if (buf_) buf_->Ref();
+}
+
+}  // namespace tensorflow

tensorflow/core/framework/tensor_reference.h

@@ -31,10 +31,7 @@ namespace tensorflow {
 class TensorReference {
  public:
   // Take the reference of the root buffer so the size will be more accurate
-  explicit TensorReference(const Tensor& tensor)
-      : buf_(tensor.buf_ ? tensor.buf_->root_buffer() : nullptr) {
-    if (buf_) buf_->Ref();
-  }
+  explicit TensorReference(const Tensor& tensor);
 
   ~TensorReference() {}
 
tensorflow/core/framework/tensor_shape.cc

@@ -109,6 +109,18 @@ void TensorShape::SlowCopyFrom(const TensorShape& b) {
   }
 }
 
+int64 TensorShape::dim_size(int d) const {
+  DCHECK_GE(d, 0);
+  DCHECK_LT(d, dims());
+  if (tag() == REP16) {
+    return as16()->dims_[d];
+  } else if (tag() == REP32) {
+    return as32()->dims_[d];
+  } else {
+    return (*as64()->dims_)[d];
+  }
+}
+
 void TensorShape::Clear() {
   ClearAllButDataType();
   set_data_type(DT_INVALID);

tensorflow/core/framework/tensor_shape.h

@@ -87,24 +87,15 @@ class TensorShape {
 
   /// Return the number of dimensions in the tensor.
   int dims() const {
-    return (tag() == REP_OUT_OF_LINE) ? (*as64()->dims_).size() : ndims_byte();
+    DCHECK(tag() != REP_OUT_OF_LINE || (*as64()->dims_).size() == ndims_byte());
+    return ndims_byte();
   }
 
   /// \brief Returns the number of elements in dimension `d`.
   /// REQUIRES: `0 <= d < dims()`
   // TODO(touts): Rename to `dimension()` to match
   // `Eigen::Tensor::dimension()`?
-  int64 dim_size(int d) const {
-    DCHECK_GE(d, 0);
-    DCHECK_LT(d, dims());
-    if (tag() == REP16) {
-      return as16()->dims_[d];
-    } else if (tag() == REP32) {
-      return as32()->dims_[d];
-    } else {
-      return (*as64()->dims_)[d];
-    }
-  }
+  int64 dim_size(int d) const;
 
   /// Returns sizes of all dimensions.
   gtl::InlinedVector<int64, 4> dim_sizes() const;

tensorflow/core/framework/unique_tensor_references.cc (new file, 76 lines)

@@ -0,0 +1,76 @@
+#include "tensorflow/core/framework/unique_tensor_references.h"
+
+namespace tensorflow {
+
+UniqueTensorReferences::~UniqueTensorReferences() {
+  if (!frozen_) {
+    // The references were not retrieved so discard them to avoid
+    // leaking memory.
+    TensorReferenceVector refs;
+    FreezeAndReturnReferences(&refs);
+    for (auto& tensor : refs) {
+      tensor.Unref();
+    }
+  }
+  delete referenced_tensors_set_;
+}
+
+void UniqueTensorReferences::Add(const Tensor& tensor) {
+  DCHECK(!frozen_);
+  // Do nothing if the tensor has a null buffer.
+  if (tensor.IsInitialized()) {
+    if (referenced_tensors_set_ != nullptr) {
+      // There are enough tensors that we are using a hash set to
+      // de-duplicate.
+      const TensorReference tensor_ref(tensor);
+      if (!referenced_tensors_set_->insert(tensor_ref).second) {
+        // The tensor was a duplicate, so discard the reference.
+        tensor_ref.Unref();
+      }
+    } else {
+      for (size_t i = 0; i < referenced_tensors_vector_.size(); ++i) {
+        if (referenced_tensors_vector_[i].SharesBufferWith(tensor)) {
+          // tensor is a duplicate, so nothing to do.
+          return;
+        }
+      }
+      referenced_tensors_vector_.push_back(TensorReference(tensor));
+      if (kInVector == referenced_tensors_vector_.size()) {
+        // There are too many tensors to keep using the N^2 algorithm
+        // so start de-duplicating using a set.
+        // Transfer the refs from the vector to the set.
+        DCHECK(referenced_tensors_set_ == nullptr);
+        referenced_tensors_set_ = new ReferencedTensorsSet;
+        referenced_tensors_set_->reserve(kInVector);
+        referenced_tensors_set_->insert(referenced_tensors_vector_.begin(),
+                                        referenced_tensors_vector_.end());
+        DCHECK_EQ(kInVector, referenced_tensors_set_->size());
+        referenced_tensors_vector_.clear();
+      }
+    }
+  }
+}
+
+void UniqueTensorReferences::FreezeAndReturnReferences(
+    TensorReferenceVector* out_vector) {
+  // Prevent any further additions.
+  frozen_ = true;
+  if (referenced_tensors_set_ != nullptr) {
+    DCHECK(referenced_tensors_vector_.empty());
+    out_vector->reserve(referenced_tensors_set_->size());
+    for (const auto& ref : *referenced_tensors_set_) {
+      out_vector->push_back(ref);
+    }
+    referenced_tensors_set_->clear();
+    delete referenced_tensors_set_;
+    referenced_tensors_set_ = nullptr;
+  } else {
+    out_vector->reserve(referenced_tensors_vector_.size());
+    for (const auto& ref : referenced_tensors_vector_) {
+      out_vector->push_back(ref);
+    }
+    referenced_tensors_vector_.clear();
+  }
+}
+
+}  // namespace tensorflow
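A note on the Add() strategy above: references are kept in a small vector and compared linearly until kInVector of them accumulate, then everything migrates into a hash set, so a handful of tensors pays no hashing cost while large counts avoid the O(N^2) scan. A self-contained sketch of the same small-buffer de-duplication idea using opaque pointers instead of TensorReference (the UniqueHandles name and the threshold value are illustrative, not TensorFlow API):

#include <cstddef>
#include <memory>
#include <unordered_set>
#include <vector>

// Collects unique, opaque handles: linear scan while the count is small,
// then migrate everything into a hash set once kInVector entries exist.
class UniqueHandles {
 public:
  void Add(const void* handle) {
    if (set_ != nullptr) {
      set_->insert(handle);  // duplicates are rejected by the set
      return;
    }
    for (const void* h : vec_) {
      if (h == handle) return;  // duplicate, nothing to do
    }
    vec_.push_back(handle);
    if (vec_.size() == kInVector) {
      // Too many entries for the quadratic scan; switch to hashing.
      set_.reset(new std::unordered_set<const void*>(vec_.begin(), vec_.end()));
      vec_.clear();
    }
  }

  std::size_t size() const { return set_ ? set_->size() : vec_.size(); }

 private:
  static const std::size_t kInVector = 4;  // threshold chosen for the sketch
  std::vector<const void*> vec_;
  std::unique_ptr<std::unordered_set<const void*>> set_;
};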
tensorflow/core/framework/unique_tensor_references.h

@@ -35,78 +35,14 @@ class UniqueTensorReferences {
  public:
   UniqueTensorReferences() : frozen_(false), referenced_tensors_set_(nullptr) {}
 
-  ~UniqueTensorReferences() {
-    if (!frozen_) {
-      // The references were not retrieved so discard them to avoid
-      // leaking memory.
-      TensorReferenceVector refs;
-      FreezeAndReturnReferences(&refs);
-      for (auto& tensor : refs) {
-        tensor.Unref();
-      }
-    }
-    delete referenced_tensors_set_;
-  }
+  ~UniqueTensorReferences();
 
   // Adds a reference to tensor if its buffer is not already referenced.
-  void Add(const Tensor& tensor) {
-    DCHECK(!frozen_);
-    // Do nothing if the tensor has a null buffer.
-    if (tensor.IsInitialized()) {
-      if (referenced_tensors_set_ != nullptr) {
-        // There are enough tensors that we are using a hash set to
-        // de-duplicate.
-        const TensorReference tensor_ref(tensor);
-        if (!referenced_tensors_set_->insert(tensor_ref).second) {
-          // The tensor was a duplicate, so discard the reference.
-          tensor_ref.Unref();
-        }
-      } else {
-        for (size_t i = 0; i < referenced_tensors_vector_.size(); ++i) {
-          if (referenced_tensors_vector_[i].SharesBufferWith(tensor)) {
-            // tensor is a duplicate, so nothing to do.
-            return;
-          }
-        }
-        referenced_tensors_vector_.push_back(TensorReference(tensor));
-        if (kInVector == referenced_tensors_vector_.size()) {
-          // There are too many tensors to keep using the N^2 algorithm
-          // so start de-duplicating using a set.
-          // Transfer the refs from the vector to the set.
-          DCHECK(referenced_tensors_set_ == nullptr);
-          referenced_tensors_set_ = new ReferencedTensorsSet;
-          referenced_tensors_set_->reserve(kInVector);
-          referenced_tensors_set_->insert(referenced_tensors_vector_.begin(),
-                                          referenced_tensors_vector_.end());
-          DCHECK_EQ(kInVector, referenced_tensors_set_->size());
-          referenced_tensors_vector_.clear();
-        }
-      }
-    }
-  }
+  void Add(const Tensor& tensor);
 
   // No more references may be added after this is called. The unique
   // references are returning in out_vector.
-  void FreezeAndReturnReferences(TensorReferenceVector* out_vector) {
-    // Prevent any further additions.
-    frozen_ = true;
-    if (referenced_tensors_set_ != nullptr) {
-      DCHECK(referenced_tensors_vector_.empty());
-      out_vector->reserve(referenced_tensors_set_->size());
-      for (const auto& ref : *referenced_tensors_set_) {
-        out_vector->push_back(ref);
-      }
-      referenced_tensors_set_->clear();
-      delete referenced_tensors_set_;
-      referenced_tensors_set_ = nullptr;
-    } else {
-      out_vector->reserve(referenced_tensors_vector_.size());
-      for (const auto& ref : referenced_tensors_vector_) {
-        out_vector->push_back(ref);
-      }
-      referenced_tensors_vector_.clear();
-    }
-  }
+  void FreezeAndReturnReferences(TensorReferenceVector* out_vector);
 
  private:
   // Up to kInVector elements are stored in reference_tensors_vector_

tensorflow/core/graph/node_builder.cc

@@ -21,6 +21,13 @@ limitations under the License.
 
 namespace tensorflow {
 
+NodeBuilder::NodeOut::NodeOut(Node* n, int i)  // NOLINT(runtime/explicit)
+    : node(n),
+      error(false),
+      name(node != nullptr ? node->name() : (error = true, "")),
+      index(i),
+      dt(SafeGetOutput(node, i, &error)) {}
+
 NodeBuilder::NodeBuilder(const string& name, const string& op_name,
                          const OpRegistryInterface* op_registry)
     : def_builder_(name, op_name, op_registry) {}

tensorflow/core/graph/node_builder.h

@@ -48,12 +48,7 @@ class NodeBuilder {
   // ArraySlice.
   struct NodeOut {
     // For referencing an existing Node.
-    NodeOut(Node* n, int i = 0)  // NOLINT(runtime/explicit)
-        : node(n),
-          error(false),
-          name(node != nullptr ? node->name() : (error = true, "")),
-          index(i),
-          dt(SafeGetOutput(node, i, &error)) {}
+    NodeOut(Node* n, int i = 0);
 
     // For referencing Nodes not in the graph being built. It is
     // useful when preparing a graph for ExtendSession or creating a

tensorflow/core/kernels/BUILD

@@ -1074,6 +1074,7 @@ filegroup(
         "io.cc",
         "lrn_op.cc",
         "maxpooling_op.cc",
+        "reduction_ops_common.cc",
         "reduction_ops_max.cc",
        "reduction_ops_mean.cc",
         "reduction_ops_min.cc",

tensorflow/core/kernels/reduction_ops_common.cc (new file, 127 lines)

@@ -0,0 +1,127 @@
+#include "tensorflow/core/kernels/reduction_ops_common.h"
+
+namespace tensorflow {
+
+TensorShape ReductionHelper::out_reshape() const {
+  TensorShape shape;
+  for (auto size : out_reshape_) shape.AddDim(size);
+  return shape;
+}
+
+// The final output shape must be allocated with this shape.
+TensorShape ReductionHelper::out_shape() const {
+  TensorShape shape;
+  for (auto size : out_shape_) shape.AddDim(size);
+  return shape;
+}
+
+TensorShape ReductionHelper::shuffled_shape() {
+  const int dims = data_reshape_.size();
+  TensorShape shape;
+  for (int i = reduce_first_axis_; i < dims; i += 2) {
+    shape.AddDim(data_reshape_[i]);
+  }
+  for (int i = !reduce_first_axis_; i < dims; i += 2) {
+    shape.AddDim(data_reshape_[i]);
+  }
+  return shape;
+}
+
+gtl::InlinedVector<int32, 8> ReductionHelper::permutation() {
+  const int dims = data_reshape_.size();
+  const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
+  gtl::InlinedVector<int32, 8> perm(dims);
+  for (int i = 0; i < unreduced_dims; i++) {
+    perm[i] = 2 * i + reduce_first_axis_;
+  }
+  for (int i = unreduced_dims; i < dims; i++) {
+    perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
+  }
+  return perm;
+}
+
+Status ReductionHelper::Simplify(const Tensor& data, const Tensor& axis,
+                                 const bool keep_dims) {
+  // bitmap[i] indicates whether to reduce data along i-th axis.
+  gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
+  auto axis_vec = axis.flat<int32>();
+  for (int64 i = 0; i < axis.NumElements(); ++i) {
+    const int32 index = axis_vec(i);
+    if (index < 0 || index >= data.dims()) {
+      return errors::OutOfRange("Invalid reduction dimension (", index,
+                                " for input with ", data.dims(),
+                                " dimension(s)");
+    }
+    bitmap[index] = true;
+  }
+
+  // Output tensor's dim sizes.
+  out_shape_.clear();
+  for (int i = 0; i < data.dims(); ++i) {
+    if (!bitmap[i]) {
+      // If we are not reducing along dimension i.
+      out_shape_.push_back(data.dim_size(i));
+    } else if (keep_dims) {
+      // We are reducing along dimension i, but we want to keep the
+      // same number of dimensions, so we set the dimension of i to
+      // '1'.
+      out_shape_.push_back(1);
+    }
+  }
+
+  // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
+  // the input data before doing the reduction on the resulting
+  // tensor. The shape of the reduction is a reshape of the final
+  // output.
+
+  // We'll skip the leading 1s.
+  int dim_index = 0;
+  for (; dim_index < data.dims(); ++dim_index) {
+    if (data.dim_size(dim_index) != 1) break;
+  }
+  if (dim_index >= data.dims()) {
+    // Special case. The input is essentially a scalar.
+    reduce_first_axis_ = true;
+  } else {
+    // Starting from the (dim_index)-th dimension, dimensions
+    // alternates between runs that need to be reduced and runs that
+    // don't.
+    //
+    // NOTE: If a dimension has size 1, we group it as the current
+    // run so that we can minimize the number of runs.
+    //
+    // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
+    // 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
+    // and reduce by axes = [1] (i.e., the output is shape [6]).
+    reduce_first_axis_ = bitmap[dim_index];
+    data_reshape_.push_back(data.dim_size(dim_index));
+    ++dim_index;
+    for (; dim_index < data.dims(); ++dim_index) {
+      const auto size = data.dim_size(dim_index);
+      if (size == 1) {
+        bitmap[dim_index] = bitmap[dim_index - 1];
+      }
+      if (bitmap[dim_index - 1] != bitmap[dim_index]) {
+        // Starts a new run of reduce or !reduce.
+        data_reshape_.push_back(size);
+      } else {
+        // Continue a run of reduce or !reduce.
+        data_reshape_.back() *= size;
+      }
+    }
+    // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
+    // are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_,
+    // otherwise, data_reshape_[0, 2, 4, ...] is.
+    for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
+         i += 2) {
+      out_reshape_.push_back(data_reshape_[i]);
+    }
+  }
+
+  VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ",");
+  VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ",");
+  VLOG(1) << "out shape: " << str_util::Join(out_shape_, ",");
+  return Status::OK();
+}
+
+}  // namespace tensorflow
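The run-collapsing step inside ReductionHelper::Simplify is the subtle part: adjacent dimensions with the same reduce/keep decision (size-1 dimensions joining the current run) are merged, so a shape like [2, 1, 3, 1, 5] reduced over axes {1, 4} is handled as a [6, 5] tensor reduced along its last axis. A standalone sketch of just that step on plain vectors (CollapseRuns and its signature are illustrative, not the TensorFlow API); it should print "6 5":

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Merges adjacent dimensions that share the same reduce/keep decision into
// single runs; size-1 dimensions are folded into the current run.  Returns
// the collapsed shape and sets *reduce_first when run 0 is a reduced run.
std::vector<int64_t> CollapseRuns(const std::vector<int64_t>& shape,
                                  std::vector<bool> reduce,
                                  bool* reduce_first) {
  std::vector<int64_t> runs;
  std::size_t i = 0;
  while (i < shape.size() && shape[i] == 1) ++i;  // skip leading 1s
  if (i == shape.size()) {
    *reduce_first = true;  // the input is effectively a scalar
    return runs;
  }
  *reduce_first = reduce[i];
  runs.push_back(shape[i]);
  for (++i; i < shape.size(); ++i) {
    if (shape[i] == 1) reduce[i] = reduce[i - 1];  // join the current run
    if (reduce[i] != reduce[i - 1]) {
      runs.push_back(shape[i]);  // start a new run of reduce or !reduce
    } else {
      runs.back() *= shape[i];   // continue the current run
    }
  }
  return runs;
}

int main() {
  bool reduce_first = false;
  // Shape [2, 1, 3, 1, 5] reduced over axes 1 and 4, as in the comment above.
  const std::vector<int64_t> runs = CollapseRuns(
      {2, 1, 3, 1, 5}, {false, true, false, false, true}, &reduce_first);
  for (int64_t r : runs) std::cout << r << " ";          // prints "6 5 "
  std::cout << "reduce_first=" << reduce_first << "\n";  // prints 0 (false)
  return 0;
}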
tensorflow/core/kernels/reduction_ops_common.h

@@ -68,95 +68,11 @@ struct Constants<CPUDevice> {
 };
 #endif
 
-namespace {
-
 class ReductionHelper {
  public:
   ReductionHelper() : reduce_first_axis_(false) {}
 
-  Status Simplify(const Tensor& data, const Tensor& axis,
-                  const bool keep_dims) {
-    // bitmap[i] indicates whether to reduce data along i-th axis.
-    gtl::InlinedVector<bool, 4> bitmap(data.dims(), false);
-    auto axis_vec = axis.flat<int32>();
-    for (int64 i = 0; i < axis.NumElements(); ++i) {
-      const int32 index = axis_vec(i);
-      if (index < 0 || index >= data.dims()) {
-        return errors::OutOfRange("Invalid reduction dimension (", index,
-                                  " for input with ", data.dims(),
-                                  " dimension(s)");
-      }
-      bitmap[index] = true;
-    }
-
-    // Output tensor's dim sizes.
-    out_shape_.clear();
-    for (int i = 0; i < data.dims(); ++i) {
-      if (!bitmap[i]) {
-        // If we are not reducing along dimension i.
-        out_shape_.push_back(data.dim_size(i));
-      } else if (keep_dims) {
-        // We are reducing along dimension i, but we want to keep the
-        // same number of dimensions, so we set the dimension of i to
-        // '1'.
-        out_shape_.push_back(1);
-      }
-    }
-
-    // Depending on bitmap[i] and bitmap[i-1], we can collapse axis of
-    // the input data before doing the reduction on the resulting
-    // tensor. The shape of the reduction is a reshape of the final
-    // output.
-
-    // We'll skip the leading 1s.
-    int dim_index = 0;
-    for (; dim_index < data.dims(); ++dim_index) {
-      if (data.dim_size(dim_index) != 1) break;
-    }
-    if (dim_index >= data.dims()) {
-      // Special case. The input is essentially a scalar.
-      reduce_first_axis_ = true;
-    } else {
-      // Starting from the (dim_index)-th dimension, dimensions
-      // alternates between runs that need to be reduced and runs that
-      // don't.
-      //
-      // NOTE: If a dimension has size 1, we group it as the current
-      // run so that we can minimize the number of runs.
-      //
-      // E.g., when we want to reduce a tensor of shape [2, 1, 3, 1,
-      // 5] by axes = [1, 4], we should treat the tensor as a [6, 5]
-      // and reduce by axes = [1] (i.e., the output is shape [6]).
-      reduce_first_axis_ = bitmap[dim_index];
-      data_reshape_.push_back(data.dim_size(dim_index));
-      ++dim_index;
-      for (; dim_index < data.dims(); ++dim_index) {
-        const auto size = data.dim_size(dim_index);
-        if (size == 1) {
-          bitmap[dim_index] = bitmap[dim_index - 1];
-        }
-        if (bitmap[dim_index - 1] != bitmap[dim_index]) {
-          // Starts a new run of reduce or !reduce.
-          data_reshape_.push_back(size);
-        } else {
-          // Continue a run of reduce or !reduce.
-          data_reshape_.back() *= size;
-        }
-      }
-      // If reduce_first_axis_ is true (input's dimension 0, 2, 4, etc
-      // are reduced), data_reshape_[1, 3, 5, ...] is out_reshape_,
-      // otherwise, data_reshape_[0, 2, 4, ...] is.
-      for (size_t i = reduce_first_axis_ ? 1 : 0; i < data_reshape_.size();
-           i += 2) {
-        out_reshape_.push_back(data_reshape_[i]);
-      }
-    }
-
-    VLOG(1) << "data reshape: " << str_util::Join(data_reshape_, ",");
-    VLOG(1) << "out reshape: " << str_util::Join(out_reshape_, ",");
-    VLOG(1) << "out shape: " << str_util::Join(out_shape_, ",");
-    return Status::OK();
-  }
+  Status Simplify(const Tensor& data, const Tensor& axis, const bool keep_dims);
 
   // We need to do roughly:
   // tmp_out = allocate(out_reshape())

@@ -164,18 +80,10 @@ class ReductionHelper {
   // out = tmp_out.reshape(out_shape)
 
   // The reduction result must be allocated with this shape.
-  TensorShape out_reshape() const {
-    TensorShape shape;
-    for (auto size : out_reshape_) shape.AddDim(size);
-    return shape;
-  }
+  TensorShape out_reshape() const;
 
   // The final output shape must be allocated with this shape.
-  TensorShape out_shape() const {
-    TensorShape shape;
-    for (auto size : out_shape_) shape.AddDim(size);
-    return shape;
-  }
+  TensorShape out_shape() const;
 
   // The reduction is on a reshaped tensor of this rank.
   int ndims() const { return data_reshape_.size(); }

@@ -203,31 +111,10 @@ class ReductionHelper {
   }
 
   // Shape with all reduction dimensions at the end
-  TensorShape shuffled_shape() {
-    const int dims = data_reshape_.size();
-    TensorShape shape;
-    for (int i = reduce_first_axis_; i < dims; i += 2) {
-      shape.AddDim(data_reshape_[i]);
-    }
-    for (int i = !reduce_first_axis_; i < dims; i += 2) {
-      shape.AddDim(data_reshape_[i]);
-    }
-    return shape;
-  }
+  TensorShape shuffled_shape();
 
   // Permutation of reduced dims needed to put reduction dimensions at the end
-  gtl::InlinedVector<int32, 8> permutation() {
-    const int dims = data_reshape_.size();
-    const int unreduced_dims = (dims + !reduce_first_axis_) / 2;
-    gtl::InlinedVector<int32, 8> perm(dims);
-    for (int i = 0; i < unreduced_dims; i++) {
-      perm[i] = 2 * i + reduce_first_axis_;
-    }
-    for (int i = unreduced_dims; i < dims; i++) {
-      perm[i] = 2 * (i - unreduced_dims) + !reduce_first_axis_;
-    }
-    return perm;
-  }
+  gtl::InlinedVector<int32, 8> permutation();
 
  private:
   bool reduce_first_axis_;  // True if need to reduce the 0-th dimension.

@@ -236,8 +123,6 @@ class ReductionHelper {
   gtl::InlinedVector<int64, 4> out_reshape_;  // Reshape output for reduction.
 };
 
-}  // end namespace
-
 // For operations where the output is a reduction function along some
 // dimensions of the input.
 template <typename Device, class T, typename Reducer>