Reduce many unnecessary Shape copies.

PiperOrigin-RevId: 254315128
This commit is contained in:
A. Unique TensorFlower 2019-06-20 18:45:02 -07:00 committed by TensorFlower Gardener
parent 81cc1b91cc
commit 9fe47db864
10 changed files with 28 additions and 23 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
namespace xla { namespace xla {
@ -81,8 +82,8 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
} }
StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto( StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
const Shape& output_shape, const HloInputOutputAliasProto& proto) { Shape output_shape, const HloInputOutputAliasProto& proto) {
HloInputOutputAliasConfig result(output_shape); HloInputOutputAliasConfig result(std::move(output_shape));
for (const HloInputOutputAliasProto::AliasEntryProto& entry : for (const HloInputOutputAliasProto::AliasEntryProto& entry :
proto.entries()) { proto.entries()) {
ShapeIndex output_index(entry.output_shape_index().begin(), ShapeIndex output_index(entry.output_shape_index().begin(),

View File

@ -57,7 +57,7 @@ class HloInputOutputAliasConfig {
HloInputOutputAliasConfig() = default; HloInputOutputAliasConfig() = default;
explicit HloInputOutputAliasConfig(Shape output_shape) explicit HloInputOutputAliasConfig(Shape output_shape)
: alias_(output_shape) {} : alias_(std::move(output_shape)) {}
virtual ~HloInputOutputAliasConfig() = default; virtual ~HloInputOutputAliasConfig() = default;
@ -86,7 +86,7 @@ class HloInputOutputAliasConfig {
HloInputOutputAliasProto ToProto() const; HloInputOutputAliasProto ToProto() const;
static StatusOr<HloInputOutputAliasConfig> CreateFromProto( static StatusOr<HloInputOutputAliasConfig> CreateFromProto(
const Shape& output_shape, const HloInputOutputAliasProto& proto); Shape output_shape, const HloInputOutputAliasProto& proto);
// Returns the output index that the given parameter and parameter index is // Returns the output index that the given parameter and parameter index is
// aliased with. A nullopt is returned if there is no output that is aliased // aliased with. A nullopt is returned if there is no output that is aliased

View File

@ -36,9 +36,9 @@ limitations under the License.
namespace xla { namespace xla {
HloModule::HloModule(const string& name, const HloModuleConfig& config) HloModule::HloModule(const string& name, HloModuleConfig config)
: name_(NameUniquer::GetSanitizedName(name)), : name_(NameUniquer::GetSanitizedName(name)),
config_(config), config_(std::move(config)),
unique_id_(next_unique_module_id_++) {} unique_id_(next_unique_module_id_++) {}
Status HloModule::set_schedule(HloSchedule schedule) { Status HloModule::set_schedule(HloSchedule schedule) {

View File

@ -64,7 +64,7 @@ class HloModule {
// only be used for HloModules used outside of the XLA service (eg // only be used for HloModules used outside of the XLA service (eg
// tests). The versioned handle is used by the service in the compilation // tests). The versioned handle is used by the service in the compilation
// cache. A default configuration is created for this module. // cache. A default configuration is created for this module.
explicit HloModule(const string& name, const HloModuleConfig& config); explicit HloModule(const string& name, HloModuleConfig config);
virtual ~HloModule() {} virtual ~HloModule() {}
// Adds an entry computation to the module. A module can only have one entry // Adds an entry computation to the module. A module can only have one entry

View File

@ -33,6 +33,9 @@ HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape,
: entry_computation_layout_( : entry_computation_layout_(
ComputationLayout(program_shape, ignore_layouts)) {} ComputationLayout(program_shape, ignore_layouts)) {}
HloModuleConfig::HloModuleConfig(ComputationLayout entry_computation_layout)
: entry_computation_layout_(std::move(entry_computation_layout)) {}
void HloModuleConfig::SetDefaultComputationLayout( void HloModuleConfig::SetDefaultComputationLayout(
const ProgramShape& program_shape) { const ProgramShape& program_shape) {
entry_computation_layout_ = ComputationLayout(program_shape); entry_computation_layout_ = ComputationLayout(program_shape);

View File

@ -45,6 +45,8 @@ class HloModuleConfig {
explicit HloModuleConfig(const ProgramShape& program_shape, explicit HloModuleConfig(const ProgramShape& program_shape,
bool ignore_layouts = true); bool ignore_layouts = true);
explicit HloModuleConfig(ComputationLayout entry_computation_layout);
// Checks if this config has an entry computation layout already. // Checks if this config has an entry computation layout already.
bool has_entry_computation_layout() const { bool has_entry_computation_layout() const {
return entry_computation_layout_.has_value(); return entry_computation_layout_.has_value();

View File

@ -31,11 +31,10 @@ limitations under the License.
namespace xla { namespace xla {
ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
const Shape& on_device_shape,
const se::Platform* platform, int device_ordinal) const se::Platform* platform, int device_ordinal)
: on_host_shape_(on_host_shape), : on_host_shape_(std::move(on_host_shape)),
on_device_shape_(on_device_shape), on_device_shape_(std::move(on_device_shape)),
platform_(platform), platform_(platform),
device_ordinal_(device_ordinal), device_ordinal_(device_ordinal),
buffers_(&on_device_shape_) {} buffers_(&on_device_shape_) {}
@ -117,12 +116,12 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
return out; return out;
} }
ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
const Shape& on_device_shape, Shape on_device_shape,
se::DeviceMemoryAllocator* allocator, se::DeviceMemoryAllocator* allocator,
int device_ordinal) int device_ordinal)
: ShapedBuffer(on_host_shape, on_device_shape, allocator->platform(), : ShapedBuffer(std::move(on_host_shape), std::move(on_device_shape),
device_ordinal), allocator->platform(), device_ordinal),
allocator_(allocator) {} allocator_(allocator) {}
ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,

View File

@ -42,7 +42,7 @@ class ShapedBuffer {
// both the on-host and on-device shape are required. The on-device shape // both the on-host and on-device shape are required. The on-device shape
// determines the number of device allocations (DeviceMemoryBase) held by the // determines the number of device allocations (DeviceMemoryBase) held by the
// ShapedBuffer. // ShapedBuffer.
ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
const se::Platform* platform, int device_ordinal); const se::Platform* platform, int device_ordinal);
// Movable, but not copyable. // Movable, but not copyable.
@ -136,8 +136,7 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);
class ScopedShapedBuffer : public ShapedBuffer { class ScopedShapedBuffer : public ShapedBuffer {
public: public:
// Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index. // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
explicit ScopedShapedBuffer(const Shape& on_host_shape, explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape,
const Shape& on_device_shape,
se::DeviceMemoryAllocator* allocator, se::DeviceMemoryAllocator* allocator,
int device_ordinal); int device_ordinal);

View File

@ -315,18 +315,19 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
ShapeUtil::HumanStringWithLayout(on_host_shape)); ShapeUtil::HumanStringWithLayout(on_host_shape));
} }
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, allocator, ScopedShapedBuffer shaped_buffer(on_host_shape, std::move(on_device_shape),
device_ordinal); allocator, device_ordinal);
// Allocate an appropriate sized buffer for each element in the shape // Allocate an appropriate sized buffer for each element in the shape
// including the tuple pointer arrays. // including the tuple pointer arrays.
for (auto& pair : shaped_buffer.buffers()) { for (auto& pair : shaped_buffer.buffers()) {
const ShapeIndex& index = pair.first; const ShapeIndex& index = pair.first;
se::DeviceMemoryBase& memory_base = pair.second; se::DeviceMemoryBase& memory_base = pair.second;
const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); const Shape& subshape =
ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index);
TF_ASSIGN_OR_RETURN(auto memory, TF_ASSIGN_OR_RETURN(auto memory,
allocator->Allocate(shaped_buffer.device_ordinal(), allocator->Allocate(shaped_buffer.device_ordinal(),
GetByteSizeRequirement(subshape))); GetByteSizeRequirement(subshape)));

View File

@ -129,7 +129,7 @@ class ShapeTree {
// this ShapeTree object, and also that the Shape is consistent with it. // this ShapeTree object, and also that the Shape is consistent with it.
void replace_shape_ptr(const Shape* shape) { void replace_shape_ptr(const Shape* shape) {
if (shape_storage_ != nullptr) { if (shape_storage_ != nullptr) {
CHECK_EQ(*shape, *shape_storage_); DCHECK_EQ(*shape, *shape_storage_);
shape_storage_ = nullptr; shape_storage_ = nullptr;
} }
shape_ = shape; shape_ = shape;