Reduce many unnecessary Shape copies.
PiperOrigin-RevId: 254315128
This commit is contained in:
parent
81cc1b91cc
commit
9fe47db864
@ -14,6 +14,7 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -81,8 +82,8 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
|
StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
|
||||||
const Shape& output_shape, const HloInputOutputAliasProto& proto) {
|
Shape output_shape, const HloInputOutputAliasProto& proto) {
|
||||||
HloInputOutputAliasConfig result(output_shape);
|
HloInputOutputAliasConfig result(std::move(output_shape));
|
||||||
for (const HloInputOutputAliasProto::AliasEntryProto& entry :
|
for (const HloInputOutputAliasProto::AliasEntryProto& entry :
|
||||||
proto.entries()) {
|
proto.entries()) {
|
||||||
ShapeIndex output_index(entry.output_shape_index().begin(),
|
ShapeIndex output_index(entry.output_shape_index().begin(),
|
||||||
|
@ -57,7 +57,7 @@ class HloInputOutputAliasConfig {
|
|||||||
HloInputOutputAliasConfig() = default;
|
HloInputOutputAliasConfig() = default;
|
||||||
|
|
||||||
explicit HloInputOutputAliasConfig(Shape output_shape)
|
explicit HloInputOutputAliasConfig(Shape output_shape)
|
||||||
: alias_(output_shape) {}
|
: alias_(std::move(output_shape)) {}
|
||||||
|
|
||||||
virtual ~HloInputOutputAliasConfig() = default;
|
virtual ~HloInputOutputAliasConfig() = default;
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ class HloInputOutputAliasConfig {
|
|||||||
HloInputOutputAliasProto ToProto() const;
|
HloInputOutputAliasProto ToProto() const;
|
||||||
|
|
||||||
static StatusOr<HloInputOutputAliasConfig> CreateFromProto(
|
static StatusOr<HloInputOutputAliasConfig> CreateFromProto(
|
||||||
const Shape& output_shape, const HloInputOutputAliasProto& proto);
|
Shape output_shape, const HloInputOutputAliasProto& proto);
|
||||||
|
|
||||||
// Returns the output index that the given parameter and parameter index is
|
// Returns the output index that the given parameter and parameter index is
|
||||||
// aliased with. A nullopt is returned if there is no output that is aliased
|
// aliased with. A nullopt is returned if there is no output that is aliased
|
||||||
|
@ -36,9 +36,9 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
HloModule::HloModule(const string& name, const HloModuleConfig& config)
|
HloModule::HloModule(const string& name, HloModuleConfig config)
|
||||||
: name_(NameUniquer::GetSanitizedName(name)),
|
: name_(NameUniquer::GetSanitizedName(name)),
|
||||||
config_(config),
|
config_(std::move(config)),
|
||||||
unique_id_(next_unique_module_id_++) {}
|
unique_id_(next_unique_module_id_++) {}
|
||||||
|
|
||||||
Status HloModule::set_schedule(HloSchedule schedule) {
|
Status HloModule::set_schedule(HloSchedule schedule) {
|
||||||
|
@ -64,7 +64,7 @@ class HloModule {
|
|||||||
// only be used for HloModules used outside of the XLA service (eg
|
// only be used for HloModules used outside of the XLA service (eg
|
||||||
// tests). The versioned handle is used by the service in the compilation
|
// tests). The versioned handle is used by the service in the compilation
|
||||||
// cache. A default configuration is created for this module.
|
// cache. A default configuration is created for this module.
|
||||||
explicit HloModule(const string& name, const HloModuleConfig& config);
|
explicit HloModule(const string& name, HloModuleConfig config);
|
||||||
virtual ~HloModule() {}
|
virtual ~HloModule() {}
|
||||||
|
|
||||||
// Adds an entry computation to the module. A module can only have one entry
|
// Adds an entry computation to the module. A module can only have one entry
|
||||||
|
@ -33,6 +33,9 @@ HloModuleConfig::HloModuleConfig(const ProgramShape& program_shape,
|
|||||||
: entry_computation_layout_(
|
: entry_computation_layout_(
|
||||||
ComputationLayout(program_shape, ignore_layouts)) {}
|
ComputationLayout(program_shape, ignore_layouts)) {}
|
||||||
|
|
||||||
|
HloModuleConfig::HloModuleConfig(ComputationLayout entry_computation_layout)
|
||||||
|
: entry_computation_layout_(std::move(entry_computation_layout)) {}
|
||||||
|
|
||||||
void HloModuleConfig::SetDefaultComputationLayout(
|
void HloModuleConfig::SetDefaultComputationLayout(
|
||||||
const ProgramShape& program_shape) {
|
const ProgramShape& program_shape) {
|
||||||
entry_computation_layout_ = ComputationLayout(program_shape);
|
entry_computation_layout_ = ComputationLayout(program_shape);
|
||||||
|
@ -45,6 +45,8 @@ class HloModuleConfig {
|
|||||||
explicit HloModuleConfig(const ProgramShape& program_shape,
|
explicit HloModuleConfig(const ProgramShape& program_shape,
|
||||||
bool ignore_layouts = true);
|
bool ignore_layouts = true);
|
||||||
|
|
||||||
|
explicit HloModuleConfig(ComputationLayout entry_computation_layout);
|
||||||
|
|
||||||
// Checks if this config has an entry computation layout already.
|
// Checks if this config has an entry computation layout already.
|
||||||
bool has_entry_computation_layout() const {
|
bool has_entry_computation_layout() const {
|
||||||
return entry_computation_layout_.has_value();
|
return entry_computation_layout_.has_value();
|
||||||
|
@ -31,11 +31,10 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
ShapedBuffer::ShapedBuffer(const Shape& on_host_shape,
|
ShapedBuffer::ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||||
const Shape& on_device_shape,
|
|
||||||
const se::Platform* platform, int device_ordinal)
|
const se::Platform* platform, int device_ordinal)
|
||||||
: on_host_shape_(on_host_shape),
|
: on_host_shape_(std::move(on_host_shape)),
|
||||||
on_device_shape_(on_device_shape),
|
on_device_shape_(std::move(on_device_shape)),
|
||||||
platform_(platform),
|
platform_(platform),
|
||||||
device_ordinal_(device_ordinal),
|
device_ordinal_(device_ordinal),
|
||||||
buffers_(&on_device_shape_) {}
|
buffers_(&on_device_shape_) {}
|
||||||
@ -117,12 +116,12 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape,
|
ScopedShapedBuffer::ScopedShapedBuffer(Shape on_host_shape,
|
||||||
const Shape& on_device_shape,
|
Shape on_device_shape,
|
||||||
se::DeviceMemoryAllocator* allocator,
|
se::DeviceMemoryAllocator* allocator,
|
||||||
int device_ordinal)
|
int device_ordinal)
|
||||||
: ShapedBuffer(on_host_shape, on_device_shape, allocator->platform(),
|
: ShapedBuffer(std::move(on_host_shape), std::move(on_device_shape),
|
||||||
device_ordinal),
|
allocator->platform(), device_ordinal),
|
||||||
allocator_(allocator) {}
|
allocator_(allocator) {}
|
||||||
|
|
||||||
ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
|
ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
|
||||||
|
@ -42,7 +42,7 @@ class ShapedBuffer {
|
|||||||
// both the on-host and on-device shape are required. The on-device shape
|
// both the on-host and on-device shape are required. The on-device shape
|
||||||
// determines the number of device allocations (DeviceMemoryBase) held by the
|
// determines the number of device allocations (DeviceMemoryBase) held by the
|
||||||
// ShapedBuffer.
|
// ShapedBuffer.
|
||||||
ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape,
|
ShapedBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||||
const se::Platform* platform, int device_ordinal);
|
const se::Platform* platform, int device_ordinal);
|
||||||
|
|
||||||
// Movable, but not copyable.
|
// Movable, but not copyable.
|
||||||
@ -136,8 +136,7 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);
|
|||||||
class ScopedShapedBuffer : public ShapedBuffer {
|
class ScopedShapedBuffer : public ShapedBuffer {
|
||||||
public:
|
public:
|
||||||
// Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
|
// Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
|
||||||
explicit ScopedShapedBuffer(const Shape& on_host_shape,
|
explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||||
const Shape& on_device_shape,
|
|
||||||
se::DeviceMemoryAllocator* allocator,
|
se::DeviceMemoryAllocator* allocator,
|
||||||
int device_ordinal);
|
int device_ordinal);
|
||||||
|
|
||||||
|
@ -315,18 +315,19 @@ StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
|
|||||||
ShapeUtil::HumanStringWithLayout(on_host_shape));
|
ShapeUtil::HumanStringWithLayout(on_host_shape));
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
|
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape));
|
||||||
const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
|
Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
|
||||||
TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
|
TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
|
||||||
|
|
||||||
ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, allocator,
|
ScopedShapedBuffer shaped_buffer(on_host_shape, std::move(on_device_shape),
|
||||||
device_ordinal);
|
allocator, device_ordinal);
|
||||||
|
|
||||||
// Allocate an appropriate sized buffer for each element in the shape
|
// Allocate an appropriate sized buffer for each element in the shape
|
||||||
// including the tuple pointer arrays.
|
// including the tuple pointer arrays.
|
||||||
for (auto& pair : shaped_buffer.buffers()) {
|
for (auto& pair : shaped_buffer.buffers()) {
|
||||||
const ShapeIndex& index = pair.first;
|
const ShapeIndex& index = pair.first;
|
||||||
se::DeviceMemoryBase& memory_base = pair.second;
|
se::DeviceMemoryBase& memory_base = pair.second;
|
||||||
const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index);
|
const Shape& subshape =
|
||||||
|
ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index);
|
||||||
TF_ASSIGN_OR_RETURN(auto memory,
|
TF_ASSIGN_OR_RETURN(auto memory,
|
||||||
allocator->Allocate(shaped_buffer.device_ordinal(),
|
allocator->Allocate(shaped_buffer.device_ordinal(),
|
||||||
GetByteSizeRequirement(subshape)));
|
GetByteSizeRequirement(subshape)));
|
||||||
|
@ -129,7 +129,7 @@ class ShapeTree {
|
|||||||
// this ShapeTree object, and also that the Shape is consistent with it.
|
// this ShapeTree object, and also that the Shape is consistent with it.
|
||||||
void replace_shape_ptr(const Shape* shape) {
|
void replace_shape_ptr(const Shape* shape) {
|
||||||
if (shape_storage_ != nullptr) {
|
if (shape_storage_ != nullptr) {
|
||||||
CHECK_EQ(*shape, *shape_storage_);
|
DCHECK_EQ(*shape, *shape_storage_);
|
||||||
shape_storage_ = nullptr;
|
shape_storage_ = nullptr;
|
||||||
}
|
}
|
||||||
shape_ = shape;
|
shape_ = shape;
|
||||||
|
Loading…
Reference in New Issue
Block a user