Merge commit for internal changes

Yifei Feng 2017-03-29 17:30:41 -07:00
commit 155332c116
214 changed files with 5638 additions and 1956 deletions


@ -297,6 +297,7 @@ filegroup(
"//tensorflow/tensorboard/backend:all_files",
"//tensorflow/tensorboard/backend/event_processing:all_files",
"//tensorflow/tensorboard/components:all_files",
"//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
"//tensorflow/tensorboard/components/vz_data_summary:all_files",
"//tensorflow/tensorboard/components/vz_line_chart:all_files",
"//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",


@ -135,6 +135,9 @@ class TF_ManagedBuffer : public TensorBuffer {
proto->set_requested_bytes(rb);
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
}
// Prevents input forwarding from mutating this buffer.
bool OwnsMemory() const override { return false; }
};
void* allocate_tensor(const char* operation, size_t len) {
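
For context, a minimal sketch of a buffer that opts out of forwarding this way; it assumes the TensorBuffer interface of this era (data/size/root_buffer/FillAllocationDescription are the required overrides) and the class name is illustrative:

// Wraps externally owned memory. Returning false from OwnsMemory() tells the
// runtime it must not forward (and thus mutate) this buffer in place.
class ExternallyOwnedBuffer : public tensorflow::TensorBuffer {
 public:
  ExternallyOwnedBuffer(void* data, size_t len) : data_(data), len_(len) {}
  void* data() const override { return data_; }
  size_t size() const override { return len_; }
  TensorBuffer* root_buffer() override { return this; }
  void FillAllocationDescription(
      tensorflow::AllocationDescription* proto) const override {
    proto->set_requested_bytes(len_);
  }
  bool OwnsMemory() const override { return false; }

 private:
  void* const data_;
  const size_t len_;
};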


@ -314,6 +314,7 @@ tf_gen_op_wrappers_cc(
name = "cc_ops",
op_lib_names = [
"array_ops",
"audio_ops",
"candidate_sampling_ops",
"control_flow_ops",
"data_flow_ops",


@ -260,6 +260,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(&xla_allocator);
run_options.set_inter_op_thread_pool(
ctx->device()->tensorflow_cpu_worker_threads()->workers);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
Env* env = Env::Default();
auto start_time = env->NowMicros();
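
The pools registered here are consumed on the executable side (see the ParallelCpuExecutable hunks below); a hedged sketch of that consumer, assuming only the ExecutableRunOptions getters and thread::ThreadPool::Schedule:

// Sketch: retrieve the inter-op pool set above and schedule work on it.
tensorflow::thread::ThreadPool* pool =
    CHECK_NOTNULL(run_options.inter_op_thread_pool());
pool->Schedule([&]() {
  // ... run one compiled computation ...
});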


@ -19,13 +19,6 @@ limitations under the License.
namespace tensorflow {
void XlaDeviceAssignOp::Copy(OpKernelContext* context, Tensor* lhs,
const Tensor& rhs) {
std::shared_ptr<xla::GlobalData> gd =
XlaTransferManager::GetTensorGlobalData(rhs);
XlaTransferManager::SetTensorGlobalData(std::move(gd), lhs);
}
XlaDeviceDummyOp::XlaDeviceDummyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) {


@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/assign_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/identity_op.h"
@ -30,14 +29,6 @@ limitations under the License.
namespace tensorflow {
// Implementation of Assign for XLA devices.
class XlaDeviceAssignOp : public AssignOp {
public:
using AssignOp::AssignOp;
void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override;
};
// Dummy OpKernel, used for kernels assigned to an XLA device that should be
// compiled. Should never be called at runtime since such ops should be
// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
@ -72,28 +63,6 @@ class XlaDeviceDummyOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \
PlaceholderOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("Variable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
VariableOp); \
REGISTER_KERNEL_BUILDER( \
Name("VariableV2").Device(DEVICE).TypeConstraint("dtype", TYPES), \
VariableOp); \
REGISTER_KERNEL_BUILDER( \
Name("TemporaryVariable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
TemporaryVariableOp); \
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
.Device(DEVICE) \
.TypeConstraint("T", TYPES), \
DestroyTemporaryVariableOp); \
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
.Device(DEVICE) \
.TypeConstraint("dtype", TYPES) \
.HostMemory("is_initialized"), \
IsVariableInitializedOp); \
REGISTER_KERNEL_BUILDER( \
Name("Assign").Device(DEVICE).TypeConstraint("T", TYPES), \
XlaDeviceAssignOp); \
\
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \
ControlTriggerOp); \
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \


@ -614,9 +614,12 @@ REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("TruncateDiv").TypeConstraint("T", kGpuIntTypes));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("TruncateMod").TypeConstraint("T", kGpuNumericTypes));
REGISTER_XLA_KERNEL(
DEVICE_GPU_XLA_JIT,
Name("TruncatedNormal").TypeConstraint("dtype", kGpuFloatTypes));
// TODO(b/34969189) The implementation of TruncatedNormal triggers a bug on GPU.
// REGISTER_XLA_KERNEL(
// DEVICE_GPU_XLA_JIT,
// Name("TruncatedNormal").TypeConstraint("dtype", kGpuFloatTypes));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("Unpack").TypeConstraint("T", kGpuAllTypes));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("VarIsInitializedOp"));


@ -301,8 +301,7 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
}
std::vector<std::unique_ptr<GlobalData>> outputs;
for (tensorflow::gtl::ArraySlice<ComputationInstance>::size_type i = 0;
i < computations.size(); ++i) {
for (size_t i = 0; i < computations.size(); ++i) {
outputs.push_back(
MakeUnique<GlobalData>(stub_, response.responses(i).output()));
if (computations[i].execution_profile != nullptr) {


@ -33,7 +33,7 @@ Computation::Computation(Computation&& computation)
}
void Computation::Reset() {
// TODO(leary) deallocate any owned computation.
// TODO(b/34469253) deallocate any owned computation.
ResetWithoutFreeing();
}


@ -106,9 +106,7 @@ bool ComputationBuilder::MakeWindow(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) {
const auto verify_size = [&](const tensorflow::gtl::ArraySlice<
int64>::size_type x,
const char* x_name) {
const auto verify_size = [&](const size_t x, const char* x_name) {
if (x == 0 || x == window_dimensions.size()) {
return true;
} else {


@ -541,8 +541,8 @@ class ComputationBuilder {
// (float32 is specified as there is an implicit float32 -1.0f constant
// exponent).
//
// TODO(leary) axe F32 suffix, can be determined by reflecting on the shape of
// the operand.
// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
// shape of the operand.
ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand);
// Enqueues a negate instruction onto the computation.
@ -839,7 +839,7 @@ template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
LiteralUtil::PopulateR4FromArray4D(values, layout, literal);
LiteralUtil::PopulateR4FromArray4DWithLayout(values, layout, literal);
});
}
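
A hedged usage sketch of the corrected call (Array4D, LayoutUtil::MakeLayout, and ComputationBuilder are existing XLA types; the concrete values are illustrative):

xla::ComputationBuilder b(client, "r4_constant");
xla::Array4D<float> values(/*n1=*/1, /*n2=*/1, /*n3=*/2, /*n4=*/2);
values.Fill(0.5f);
// Minor-to-major order {0,1,2,3}; PopulateR4FromArray4DWithLayout honors it.
xla::Layout layout = xla::LayoutUtil::MakeLayout({0, 1, 2, 3});
xla::ComputationDataHandle h =
    b.ConstantR4FromArray4DWithLayout<float>(values, layout);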


@ -309,6 +309,14 @@ int LocalClient::default_device_ordinal() const {
return local_service_->backend().default_device_ordinal();
}
const Backend& LocalClient::backend() const {
return local_service_->backend();
}
Backend* LocalClient::mutable_backend() {
return local_service_->mutable_backend();
}
StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
const Computation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,


@ -224,6 +224,10 @@ class LocalClient : public Client {
// capability).
bool device_ordinal_supported(int device_ordinal) const;
// Returns the backend used to execute computations.
const Backend& backend() const;
Backend* mutable_backend();
private:
LocalService* local_service_;
};


@ -35,8 +35,7 @@ std::vector<std::pair<int64, int64>> MakePadding(
return low_high_padding;
case Padding::kSame:
for (tensorflow::gtl::ArraySlice<int64>::size_type i = 0;
i < input_dimensions.size(); ++i) {
for (size_t i = 0; i < input_dimensions.size(); ++i) {
int64 input_dimension = input_dimensions[i];
int64 window_dimension = window_dimensions[i];
int64 window_stride = window_strides[i];
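
For reference, kSame here follows the usual SAME-padding convention; a sketch of the per-dimension arithmetic (not a quote of the implementation):

output_size = ceil(input_dimension / window_stride)
padding_total = max((output_size - 1) * window_stride + window_dimension - input_dimension, 0)
(low, high) = (padding_total / 2, padding_total - padding_total / 2)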


@ -32,8 +32,7 @@ namespace xla {
// Padding and nested layouts not supported yet.
DCHECK_EQ(0, shape.layout().padded_dimensions_size());
for (tensorflow::gtl::ArraySlice<int64>::size_type i = 0;
i < multi_index.size(); ++i) {
for (size_t i = 0; i < multi_index.size(); ++i) {
DCHECK_GE(multi_index[i], 0);
DCHECK_LT(multi_index[i], shape.dimensions(i))
<< "indexing beyond extent in dimension " << i << ":"


@ -1133,6 +1133,22 @@ cc_library(
],
)
cc_library(
name = "hlo_verifier",
srcs = ["hlo_verifier.cc"],
hdrs = ["hlo_verifier.h"],
deps = [
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)
cc_library(
name = "hlo_rematerialization",
srcs = ["hlo_rematerialization.cc"],


@ -896,7 +896,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
PaddingConfig no_padding;
for (auto i = 0; i < 2; ++i) {
for (int i = 0; i < 2; ++i) {
auto dimension = no_padding.add_dimensions();
dimension->set_edge_padding_low(0);
dimension->set_edge_padding_high(0);
@ -926,7 +926,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
PaddingConfig padding;
int64 low_padding[2] = {-1, -2};
int64 high_padding[2] = {2, -3};
for (auto i = 0; i < 2; ++i) {
for (int i = 0; i < 2; ++i) {
auto dimension = padding.add_dimensions();
dimension->set_edge_padding_low(low_padding[i]);
dimension->set_edge_padding_high(high_padding[i]);


@ -138,8 +138,7 @@ tensorflow::Status AllocationTracker::DeallocateShape(
TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size())
<< "tuple has unexpected number of elements: " << elements.size()
<< " != " << ShapeUtil::TupleElementCount(shape);
for (std::vector<se::DeviceMemoryBase>::size_type i = 0;
i < elements.size(); ++i) {
for (size_t i = 0; i < elements.size(); ++i) {
VLOG(2) << "recursing onto the tuple elements";
TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i],
shape.tuple_shapes(i),


@ -212,6 +212,13 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
return GetUniqueSlice(instruction, /*index=*/{});
}
bool BufferAssignment::SharesSliceAtIndex(
const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
}
StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
return GetUniqueTopLevelSlice(


@ -294,6 +294,15 @@ class BufferAssignment {
return GetPointsToSet(instruction).element(index);
}
// Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
// share the same BufferAllocation::Slice.
// Returns false otherwise.
// REQUIRES: BufferAssignment assigned allocations to both instructions.
bool SharesSliceAtIndex(const HloInstruction* hlo_a,
const ShapeIndex& shape_index_a,
const HloInstruction* hlo_b,
const ShapeIndex& shape_index_b) const;
// Returns the underlying points-to analysis used for this assignment.
const TuplePointsToAnalysis& points_to_analysis() const {
return liveness_->points_to_analysis();
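
A hedged sketch of the intended use, mirroring the GPU in-place DynamicUpdateSlice check later in this commit:

// True iff the fusion's top-level buffer aliases `operand`'s buffer at
// `index`, i.e. the update can be emitted in place.
bool in_place =
    assignment.SharesSliceAtIndex(fusion, /*shape_index_a=*/{}, operand, index);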


@ -121,6 +121,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
// *) Is element-wise.
// *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where
// the singleton use of 'a' at 'a.index' is the fused root at operand 0.
// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
if (b.instruction()->IsUserOf(alias.instruction()) &&
!CanShareOperandBufferWithUser(alias.instruction(), alias.index(),


@ -612,6 +612,93 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
}
class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
protected:
// Builds and runs a computation (see test case computation graphs below).
// Runs BufferLiveness on this computation.
// Returns whether buffer interference is detected between tuple-shaped
// parameter and root instructions at tuple element 1.
bool Run(const bool tuple_element1_has_two_uses) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape update_shape = ShapeUtil::MakeShape(F32, {3});
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));
auto update = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));
if (tuple_element1_has_two_uses) {
// Add 'gte0' and 'gte1' to create another user of 'gte1'.
gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kAdd, gte0, gte1));
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
// Create output tuple.
auto tuple_root = builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
auto module = MakeUnique<HloModule>(TestName());
module->AddEntryComputation(builder.Build());
// Run BufferLiveness on 'module'.
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Return whether or not buffer interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
}
};
// Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
// the following computation (because DynamicUpdateSlice (at operand 0) is the
// unique user):
//
// Parameter0
// | |
// GTE(0) GTE(1) Const Const
// | \ | /
// | DynamicUpdateSlice
// \ /
// Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
}
// Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
// GTE(1) has two users:
// 1) DynamicUpdateSlice at operand 0.
// 2) Add at operand 1.
//
// Parameter0
// | |
// GTE(0) GTE(1)
// | / |
// | / |
// Add | Const Const
// | | | |
// | DynamicUpdateSlice
// \ /
// Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
}
} // namespace
} // namespace xla


@ -110,8 +110,7 @@ class Compiler {
// The compiler may optionally specialize to the individual device
// (not just type of device) indicated by the executor.
//
// TODO(leary) will need to update this API when a single computation can run
// across multiple devices simultaneously.
// Use the overload below to compile computations that run in parallel.
virtual StatusOr<std::unique_ptr<Executable>> Compile(
std::unique_ptr<HloModule> module,
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,


@ -69,6 +69,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",


@ -68,6 +68,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
@ -214,6 +215,7 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module,
HloDumper dump_hlo) {
// Optimization pipeline.
HloPassPipeline pipeline("CPU", dump_hlo);
pipeline.AddInvariantChecker<HloVerifier>();
// TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding
// where we will take this pass in future.
@ -573,8 +575,7 @@ CpuCompiler::CompileAheadOfTime(
}
std::vector<std::unique_ptr<AotCompilationResult>> results;
for (std::vector<std::unique_ptr<HloModule>>::size_type i = 0;
i < hlo_modules.size(); ++i) {
for (size_t i = 0; i < hlo_modules.size(); ++i) {
HloModule* hlo_module = hlo_modules[i].get();
HloModuleConfig* module_config = module_configs[i].get();


@ -24,8 +24,9 @@ namespace cpu {
class CpuInstructionFusion : public InstructionFusion {
public:
CpuInstructionFusion() {}
~CpuInstructionFusion() override {}
CpuInstructionFusion()
: InstructionFusion(CpuInstructionFusion::IsExpensive) {}
~CpuInstructionFusion() override = default;
protected:
bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override;


@ -1111,7 +1111,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
llvm_ir::IrArray::Index input_index = reduced_dims_index;
llvm_ir::IrArray::Index::const_iterator it = index.begin();
for (auto i = 0; i < input_index.size(); ++i) {
for (size_t i = 0; i < input_index.size(); ++i) {
if (input_index[i] == nullptr) {
input_index[i] = *it++;
}
@ -1180,7 +1180,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
// output_index := edge_padding_low + operand_index * (interior_padding + 1)
const PaddingConfig& padding_config = pad->padding_config();
llvm_ir::IrArray::Index output_index;
for (auto i = 0; i < operand_index.size(); ++i) {
for (size_t i = 0; i < operand_index.size(); ++i) {
llvm::Value* offset = ir_builder_.CreateMul(
operand_index[i],
ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() +
@ -1294,12 +1294,12 @@ Status IrEmitter::HandleCustomCall(
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
i8_ptr_type, ir_builder_.getInt32(operands.size()),
"cc_operands_alloca", &ir_builder_);
for (auto i = 0; i < operands.size(); ++i) {
for (size_t i = 0; i < operands.size(); ++i) {
const HloInstruction* operand = operands[i];
llvm::Value* operand_as_i8ptr =
ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP(
operands_alloca, {ir_builder_.getInt32(i)});
operands_alloca, {ir_builder_.getInt64(i)});
ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
}
auto* custom_call_ir_function =
@ -1659,13 +1659,13 @@ void IrEmitter::EmitArrayFunctionCallInto(
ir_builder_.getInt32(parameter_addresses.size()),
tensorflow::strings::StrCat(name, "_parameter_addresses"),
&ir_builder_);
for (auto i = 0; i < parameter_addresses.size(); ++i) {
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast(
parameter_addresses[i], ir_builder_.getInt8PtrTy(),
llvm_ir::AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i,
"_address_as_i8ptr")));
llvm::Value* slot_in_param_addresses = ir_builder_.CreateInBoundsGEP(
parameter_addresses_buffer, {ir_builder_.getInt32(i)});
parameter_addresses_buffer, {ir_builder_.getInt64(i)});
ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
}


@ -97,77 +97,81 @@ static void MarkLiveAddressesInOutput(
}
}
StatusOr<perftools::gputools::DeviceMemoryBase>
ParallelCpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
VLOG(3) << "ExecuteOnStream arg size: " << arguments.size();
if (!arguments.empty()) {
VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque();
}
// Allocate the temporary buffers required for the computation.
se::StreamExecutor* stream_executor = stream->parent();
int device_ordinal = stream_executor->device_ordinal();
int64 buffer_count = assignment_->Allocations().size();
VLOG(3) << "temp buffer count: " << buffer_count;
std::vector<se::DeviceMemoryBase> device_allocations;
for (BufferAllocation::Index i = 0; i < buffer_count; ++i) {
Status ParallelCpuExecutable::AllocateBuffers(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
std::vector<perftools::gputools::DeviceMemoryBase>* buffers) {
CHECK_EQ(buffers->size(), assignment_->Allocations().size());
VLOG(3) << "Allocating " << assignment_->Allocations().size()
<< " allocations for module " << module().name();
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
++i) {
auto& allocation = assignment_->GetAllocation(i);
VLOG(3) << allocation.ToString();
if (allocation.is_entry_computation_parameter()) {
// Buffers do not need to be allocated for parameters.
device_allocations.push_back(se::DeviceMemoryBase(nullptr));
VLOG(3) << "allocation #" << i << " is a parameter";
continue;
}
if (allocation.is_thread_local()) {
// Buffers do not need to be allocated for thread-local temporaries.
device_allocations.push_back(se::DeviceMemoryBase(nullptr));
VLOG(3) << "buffer #" << i << " is thread-local";
continue;
}
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase device_allocation,
memory_allocator->Allocate(device_ordinal, allocation.size()));
int64 buffer_size = allocation.size();
if (!(*buffers)[i].is_null()) {
VLOG(3) << "buffer #" << i
<< " is in the preallocated result ShapedBuffer";
} else {
TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate(
device_ordinal, buffer_size));
if (VLOG_IS_ON(3)) {
VLOG(3) << "ParallelCpuExecutable allocating " << allocation.size()
<< " bytes for allocation #" << i << " ["
<< device_allocation.opaque() << "]";
std::vector<string> parts;
for (const auto& buffer_offset_size : allocation.assigned_buffers()) {
const LogicalBuffer& buffer = *buffer_offset_size.first;
parts.push_back(tensorflow::strings::StrCat(
buffer.instruction()->parent()->name(), "::", buffer.ToString()));
}
VLOG(3) << " " << tensorflow::str_util::Join(parts, ", ");
VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
<< (*buffers)[i].opaque() << "]";
}
device_allocations.push_back(device_allocation);
// Since the output buffer and all the temporary buffers were written into
// by the JITed code, msan has no way of knowing their memory was
// initialized. Mark them initialized so that msan doesn't flag loads from
// these buffers.
TF_ANNOTATE_MEMORY_IS_INITIALIZED(device_allocation.opaque(),
allocation.size());
TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size);
}
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelOutputSlice());
const BufferAllocation::Index result_index = result_slice.index();
VLOG(3) << "result index: " << result_index;
VLOG(3) << "result index: " << result_slice.index();
return Status::OK();
}
Status ParallelCpuExecutable::ExecuteComputeFunctions(
const ExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
std::vector<se::DeviceMemoryBase> argument_buffers(arguments.size());
for (int i = 0; i < arguments.size(); ++i) {
TF_RET_CHECK(!ShapeUtil::IsTuple(arguments[i]->shape()));
argument_buffers[i] = arguments[i]->buffer(/*index=*/{});
}
return ExecuteComputeFunctions(run_options, argument_buffers, buffers,
hlo_execution_profile);
}
Status ParallelCpuExecutable::ExecuteComputeFunctions(
const ExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile) {
// Allocate profiling counters for each hlo instruction that we would like to
// profile. Allocate an additional profile counter for the entire
// computation.
std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
std::vector<void*> buffer_pointers;
for (auto& device_allocation : device_allocations) {
buffer_pointers.reserve(buffers.size());
for (auto device_allocation : buffers) {
buffer_pointers.push_back(device_allocation.opaque());
}
@ -210,8 +214,7 @@ ParallelCpuExecutable::ExecuteOnStream(
void** temps_array = buffer_pointers.data();
uint64* profile_counters_array = profile_counters.data();
auto* thread_pool =
CHECK_NOTNULL(run_options->run_options().inter_op_thread_pool());
auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool());
tensorflow::mutex completion_queue_lock;
tensorflow::condition_variable completion_queue_cv;
std::deque<HloInstruction*> completion_queue;
@ -310,6 +313,42 @@ ParallelCpuExecutable::ExecuteOnStream(
}
}
return Status::OK();
}
StatusOr<perftools::gputools::DeviceMemoryBase>
ParallelCpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
HloExecutionProfile* hlo_execution_profile) {
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
VLOG(3) << "ExecuteOnStream arg size: " << arguments.size();
if (!arguments.empty()) {
VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque();
}
// Allocate the temporary buffers required for the computation.
se::StreamExecutor* stream_executor = stream->parent();
int device_ordinal = stream_executor->device_ordinal();
int64 buffer_count = assignment_->Allocations().size();
VLOG(3) << "temp buffer count: " << buffer_count;
std::vector<se::DeviceMemoryBase> device_allocations(
assignment_->Allocations().size());
TF_RETURN_IF_ERROR(AllocateBuffers(memory_allocator,
stream->parent()->device_ordinal(),
&device_allocations));
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
assignment_->GetUniqueTopLevelOutputSlice());
const BufferAllocation::Index result_index = result_slice.index();
VLOG(3) << "result index: " << result_index;
TF_RETURN_IF_ERROR(ExecuteComputeFunctions(&run_options->run_options(),
arguments, device_allocations,
hlo_execution_profile));
// Mark the buffers that are actually live (used in the output) when the
// computation finishes executing.
std::unordered_set<const void*> marked_addresses;
@ -328,7 +367,7 @@ ParallelCpuExecutable::ExecuteOnStream(
// live because they are referenced by the output of the computation
// and are needed by the service. They will be deallocated by the
// service.
for (auto i = 0; i < device_allocations.size(); ++i) {
for (size_t i = 0; i < device_allocations.size(); ++i) {
auto alloc = device_allocations[i];
if (marked_addresses.count(alloc.opaque()) == 0 &&
alloc.opaque() != nullptr) {
@ -345,8 +384,74 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ParallelCpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) {
return Unimplemented(
"ParallelCpuExecutable not supported yet with LocalService execution");
if (GetRootPointsToSet().IsAmbiguous()) {
return Unimplemented("Points-to set of root instruction is ambiguous");
}
se::Stream* stream = run_options->stream();
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> result_buffer,
ShapedBuffer::MakeShapedBuffer(
result_shape(), stream->parent()->platform(),
stream->parent()->device_ordinal()));
TF_RETURN_IF_ERROR(AllocateBuffers(
memory_allocator, stream->parent()->device_ordinal(), &buffers));
TF_RETURN_IF_ERROR(ExecuteComputeFunctions(
&run_options->run_options(), arguments, buffers, hlo_execution_profile));
// Copy DeviceMemoryBase values which contain the array(s) of the result into
// the respective location in ShapedBuffer which is returned to the caller.
std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
TF_RETURN_IF_ERROR(
result_buffer->mutable_shape_index_to_buffer_entry()
->ForEachMutableElement(
[&buffers, &buffers_in_result, &result_buffer, this](
const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) {
if (is_leaf) {
const std::vector<const LogicalBuffer*>& sources =
this->GetRootPointsToSet().element(index);
// The points-to set is unambiguous so the set should be a
// singleton.
CHECK_EQ(1, sources.size());
const LogicalBuffer* buffer_source = sources[0];
HloInstruction* src = buffer_source->instruction();
// The source for this result buffer can be a nested buffer
// such as a tuple element.
// The source instruction should have a non-parameter buffer
// assigned.
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
this->assignment_->GetUniqueSlice(
src, buffer_source->index()));
CHECK(!slice.allocation()->is_entry_computation_parameter());
const BufferAllocation::Index buffer_index = slice.index();
const se::DeviceMemoryBase& buffer = buffers[buffer_index];
CHECK(!buffer.is_null() || buffer.size() == 0);
*buffer_entry = result_buffer->mutable_buffers()->size();
result_buffer->mutable_buffers()->push_back(buffer);
buffers_in_result[buffer_index] = true;
}
return Status::OK();
}));
// Free all buffers not in the result.
for (size_t i = 0; i < buffers.size(); ++i) {
se::DeviceMemoryBase alloc = buffers[i];
if (!buffers_in_result[i] && !alloc.is_null()) {
VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
<< alloc.opaque() << "]";
TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
stream->parent()->device_ordinal(), &alloc));
}
}
return std::move(result_buffer);
}
StatusOr<perftools::gputools::DeviceMemoryBase>
@ -358,5 +463,10 @@ ParallelCpuExecutable::ExecuteAsyncOnStream(
"Asynchronous execution on stream is not yet supported on CPU.");
}
const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const {
return assignment_->points_to_analysis().GetPointsToSet(
module().entry_computation()->root_instruction());
}
} // namespace cpu
} // namespace xla


@ -84,6 +84,35 @@ class ParallelCpuExecutable : public Executable {
}
private:
// Allocate buffers required for execution and assign them to the elements of
// "buffers". "buffers" should be sized to the number of buffers in buffer
// assignment. Each vector element corresponds to a particular Index. If
// a vector element already contains a non-null DeviceMemoryBase, then no
// buffer is assigned for this element.
Status AllocateBuffers(
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
std::vector<perftools::gputools::DeviceMemoryBase>* buffers);
// Calls the generated functions in 'function_names_', performing the
// computation with the given arguments using the supplied buffers.
Status ExecuteComputeFunctions(
const ExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
buffers,
HloExecutionProfile* hlo_execution_profile);
Status ExecuteComputeFunctions(
const ExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
buffers,
HloExecutionProfile* hlo_execution_profile);
// Returns the points-to set of the root instruction of the entry
// computation. Uses points-to analysis from buffer assignment.
const PointsToSet& GetRootPointsToSet() const;
// The JIT containing compiled modules.
tensorflow::mutex jit_mutex_;
std::unique_ptr<SimpleOrcJIT> jit_ GUARDED_BY(jit_mutex_);
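
A minimal sketch of the calling convention described by the AllocateBuffers comment above (it mirrors the new ExecuteOnStream body in the .cc hunk):

// One slot per BufferAllocation::Index; slots that already hold a non-null
// DeviceMemoryBase (e.g. preallocated result buffers) are left untouched.
std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
TF_RETURN_IF_ERROR(AllocateBuffers(
    memory_allocator, stream->parent()->device_ordinal(), &buffers));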


@ -937,6 +937,68 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
};
case HloOpcode::kRng:
return MakeRngElementGenerator(hlo, operand_to_generator);
case HloOpcode::kPad:
return [=, &operand_to_generator](
const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
auto index = padded_index;
llvm::Value* in_bounds = ir_builder_->getTrue();
for (size_t i = 0; i < index.size(); ++i) {
auto index_typed_const = [=](int64 n) {
return llvm::ConstantInt::get(index[i]->getType(), n);
};
const auto& pad_dim = hlo->padding_config().dimensions(i);
index[i] = ir_builder_->CreateSub(
index[i], index_typed_const(pad_dim.edge_padding_low()));
in_bounds = ir_builder_->CreateAnd(
in_bounds,
ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
"in_bounds");
in_bounds = ir_builder_->CreateAnd(
in_bounds,
ir_builder_->CreateICmpEQ(
index_typed_const(0),
ir_builder_->CreateURem(
index[i],
index_typed_const(pad_dim.interior_padding() + 1))),
"in_bounds");
index[i] = ir_builder_->CreateSDiv(
index[i], index_typed_const(pad_dim.interior_padding() + 1));
in_bounds = ir_builder_->CreateAnd(
in_bounds,
ir_builder_->CreateICmpSLT(
index[i],
index_typed_const(hlo->operand(0)->shape().dimensions(i))),
"in_bounds");
}
// if (in_bounds) {
// ret_value = operand0[index]; // source
// } else {
// ret_value = *operand1; // padding
// }
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_),
"pad_result_addr", ir_builder_);
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
SetToFirstInsertPoint(if_data.true_block, ir_builder_);
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(index));
ir_builder_->CreateStore(operand_value, ret_value_addr);
SetToFirstInsertPoint(if_data.false_block, ir_builder_);
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
operand_to_generator.at(hlo->operand(1))({}));
ir_builder_->CreateStore(padding_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
// Don't create phi(operand_value, padding_value) here, because invoking
// operand_to_generator may create new basic blocks, making the parent
// of operand_value or padding_value no longer a predecessor of
// if_data.after_block.
return ir_builder_->CreateLoad(ret_value_addr);
};
default:
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
return Unimplemented("%s", HloOpcodeString(hlo->opcode()).c_str());


@ -40,8 +40,7 @@ Executable::ExecuteOnStreams(
std::vector<perftools::gputools::DeviceMemoryBase> return_values(
run_options.size());
for (tensorflow::gtl::ArraySlice<const ExecutableRunOptions>::size_type i = 0;
i < run_options.size(); ++i) {
for (size_t i = 0; i < run_options.size(); ++i) {
// We cannot BlockHostUntilDone() on the already-launched executions in case
// of error, since if the executions communicate, the initially launched
// executions may never complete if not all executions are running.


@ -39,9 +39,6 @@ namespace xla {
// A given platform's compiler will produce an Executable -- this is a uniform
// interface that is used for launching compiled programs across platforms.
//
// TODO(leary) will need to extend this to support multiple streams/devices as
// we begin to compile single programs to run on multiple devices.
class Executable {
public:
explicit Executable(std::unique_ptr<HloModule> hlo_module,


@ -118,7 +118,7 @@ GenericTransferManager::ShallowCopyTupleFromDevice(
// Create a DeviceMemoryBase from each void* pointer.
std::vector<se::DeviceMemoryBase> destination;
for (std::vector<void*>::size_type i = 0; i < element_pointers.size(); ++i) {
for (size_t i = 0; i < element_pointers.size(); ++i) {
if (element_pointers[i] == nullptr &&
!ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
return FailedPrecondition("tuple contains nullptr at element %lu", i);


@ -356,13 +356,12 @@ cc_library(
srcs = ["fusion_merger.cc"],
hdrs = ["fusion_merger.h"],
deps = [
":instruction_fusion",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:instruction_fusion",
"//tensorflow/core:lib",
],
)
@ -434,6 +433,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",


@ -87,7 +87,8 @@ tensorflow::Status BufferAllocations::TearDown(
const std::set<se::DeviceMemoryBase>& live_addresses,
const BufferAssignment& buffer_assignment) {
// Deallocate temporary buffers.
for (auto i = 0; i < buffer_assignment.Allocations().size(); ++i) {
const int64 num_buffers = buffer_assignment.Allocations().size();
for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index());
// Deallocate buffers marked "maybe_live_out" but aren't actually live out,


@ -270,69 +270,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
const HloInstruction* hlo,
const HloToElementGeneratorMap& operand_to_generator) const {
switch (hlo->opcode()) {
case HloOpcode::kPad:
return [=, &operand_to_generator](
const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
auto index = padded_index;
llvm::Value* in_bounds =
llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1);
for (auto i = 0; i < index.size(); ++i) {
auto index_typed_const = [=](int64 n) {
return llvm::ConstantInt::get(index[i]->getType(), n);
};
const auto& pad_dim = hlo->padding_config().dimensions(i);
index[i] = ir_builder_->CreateSub(
index[i], index_typed_const(pad_dim.edge_padding_low()));
in_bounds = ir_builder_->CreateAnd(
in_bounds,
ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
"in_bounds");
in_bounds = ir_builder_->CreateAnd(
in_bounds,
ir_builder_->CreateICmpEQ(
index_typed_const(0),
ir_builder_->CreateURem(
index[i],
index_typed_const(pad_dim.interior_padding() + 1))),
"in_bounds");
index[i] = ir_builder_->CreateSDiv(
index[i], index_typed_const(pad_dim.interior_padding() + 1));
in_bounds = ir_builder_->CreateAnd(
in_bounds,
ir_builder_->CreateICmpSLT(
index[i],
index_typed_const(hlo->operand(0)->shape().dimensions(i))),
"in_bounds");
}
// if (in_bounds) {
// ret_value = operand0[index]; // source
// } else {
// ret_value = *operand1; // padding
// }
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
ir_builder_),
"pad_result_addr", ir_builder_);
llvm_ir::LlvmIfData if_data =
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
SetToFirstInsertPoint(if_data.true_block, ir_builder_);
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(index));
ir_builder_->CreateStore(operand_value, ret_value_addr);
SetToFirstInsertPoint(if_data.false_block, ir_builder_);
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
operand_to_generator.at(hlo->operand(1))({}));
ir_builder_->CreateStore(padding_value, ret_value_addr);
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
// Don't create phi(operand_value, padding_value) here, because invoking
// operand_to_generator may create new basic blocks, making the parent
// of operand_value or padding_value no longer a predecessor of
// if_data.after_block.
return ir_builder_->CreateLoad(ret_value_addr);
};
case HloOpcode::kMap:
return [=, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {


@ -18,8 +18,8 @@ limitations under the License.
#include <algorithm>
#include <vector>
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
@ -221,7 +221,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
fusion->fused_instructions().end(),
[](const std::unique_ptr<HloInstruction>& instruction) {
if (instruction->opcode() != HloOpcode::kParameter &&
IsExpensive(*instruction)) {
GpuInstructionFusion::IsExpensive(*instruction)) {
return false;
}
return true;


@ -51,6 +51,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
@ -121,6 +122,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
const se::DeviceDescription& device_desc) {
{
HloPassPipeline pipeline("optimization", dump_hlo);
pipeline.AddInvariantChecker<HloVerifier>();
{
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
"simplification", dump_hlo);
@ -157,6 +159,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
// (b/27180329). Therefore, in that case, we set the output to be a copy of
// the parameter.
HloPassPipeline pipeline("GPU-ir-emit-prepare", dump_hlo);
pipeline.AddInvariantChecker<HloVerifier>();
pipeline.AddPass<PadInsertion>();
pipeline.AddPass<GpuLayoutAssignment>(
module_config->mutable_entry_computation_layout());


@ -25,7 +25,7 @@ namespace gpu {
class GpuInstructionFusion : public InstructionFusion {
public:
explicit GpuInstructionFusion(bool may_duplicate)
: InstructionFusion(may_duplicate) {}
: InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {}
bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override;


@ -513,7 +513,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
llvm_ir::IrArray::Index input_index = reduced_dims_index;
llvm_ir::IrArray::Index::const_iterator it = index.begin();
for (auto i = 0; i < input_index.size(); ++i) {
for (size_t i = 0; i < input_index.size(); ++i) {
if (input_index[i] == nullptr) {
input_index[i] = *it++;
}
@ -614,7 +614,7 @@ llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest(
llvm_ir::IrArray::Index index =
loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
// Verify every dimension except the reduction dimension was set in the index.
for (auto dimension = 0; dimension < index.size(); ++dimension) {
for (size_t dimension = 0; dimension < index.size(); ++dimension) {
if (dimension == reduction_dimension) {
DCHECK_EQ(nullptr, index[dimension]);
} else {


@ -283,14 +283,7 @@ bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment,
return false;
}
auto* operand = fusion->operand(fusion_operand->parameter_number());
BufferAllocation::Slice operand_slice =
assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie();
BufferAllocation::Slice fusion_slice =
assignment.GetUniqueTopLevelSlice(fusion).ConsumeValueOrDie();
return operand_slice == fusion_slice;
return assignment.SharesSliceAtIndex(fusion, {}, operand, index);
}
} // namespace
@ -387,9 +380,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
// Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0.
ShapeIndex index_unused;
auto* fusion_operand =
LatestNonGteAncestorAndIndex(root->operand(0), &index_unused);
auto* fusion_operand = LatestNonGteAncestor(root->operand(0));
CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode());
// Operand(0) is the input array which shares an allocation with the output.


@ -79,7 +79,7 @@ ThunkSchedule::ThunkSchedule(
void ThunkSchedule::RemoveRedundantDependencyEdges() {
std::unordered_map<const Thunk*, int> thunk_to_total_order;
for (auto i = 0; i < thunk_total_order_.size(); ++i) {
for (int i = 0; i < thunk_total_order_.size(); ++i) {
InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i);
}


@ -92,13 +92,22 @@ HloInstruction* HloComputation::AddInstructionInternal(
// Generate a unique name for the instruction.
instruction->set_name(
instruction_name_uniquer_.GetUniqueName(instruction->name()));
instruction->set_parent(this);
Reparent(instruction.get());
HloInstruction* pinst = instruction.get();
instruction_iterators_[pinst] =
instructions_.insert(instructions_.end(), std::move(instruction));
return pinst;
}
void HloComputation::Reparent(HloInstruction* instruction) {
instruction->set_parent(this);
if (instruction->opcode() == HloOpcode::kFusion) {
for (auto& i : instruction->fused_instructions()) {
Reparent(i.get());
}
}
}
/* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) {
return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv ||
opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace ||


@ -235,6 +235,14 @@ class HloComputation {
HloInstruction* AddInstructionInternal(
std::unique_ptr<HloInstruction> instruction);
// Helper for setting the parent of instructions that are added to this
// computation.
//
// Because we clone HLO instructions without knowing what computation they're
// destined to be added to, this is required to appropriately set the parent on
// fused instruction sequences.
void Reparent(HloInstruction* instruction);
// Fuses HLOs in instructions_to_fuse into fusion_instruction.
//
// Pre-condition: fusion_instruction's opcode is kFusion.


@ -75,8 +75,7 @@ string InstructionSequenceGraph(
std::vector<HloInstruction*> param_instructions;
for (auto& instruction : instructions) {
if (instruction->opcode() == HloOpcode::kParameter) {
std::vector<HloInstruction*>::size_type param_number =
instruction->parameter_number();
size_t param_number = instruction->parameter_number();
if (param_instructions.size() < param_number + 1) {
param_instructions.resize(param_number + 1, nullptr);


@ -431,6 +431,7 @@ HloInstruction::CreateSelectAndScatter(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
instruction->fusion_kind_ = fusion_kind;
instruction->set_parent(fused_root->parent());
instruction->CloneAndFuseInternal(fused_root);
instruction->CheckFusionInstruction();
return instruction;
@ -568,6 +569,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
std::unique_ptr<HloInstruction> param_instruction =
CreateParameter(param_no, operand->shape(), "fusion_param");
param_instruction->set_parent(parent());
param_instruction->parent_fusion_instruction_ = this;
fused_parameters_.push_back(param_instruction.get());
fused_instructions_.push_back(std::move(param_instruction));
@ -602,6 +604,7 @@ void HloInstruction::CheckFusionInstruction() const {
for (auto& instruction : fused_instructions_) {
CHECK(instruction->IsFused());
CHECK_EQ(this, instruction->fusion_instruction());
CHECK_EQ(parent(), instruction->parent()) << instruction->ToString();
}
// Fused root instruction and fused parameters must all be owned by the fusion
@ -838,12 +841,14 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix) {
std::unique_ptr<HloInstruction> clone =
CloneWithNewOperands(shape_, operands_);
clone->name_ = name() + "." + suffix;
clone->set_parent(parent());
return clone;
}
std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
CHECK_EQ(opcode_, HloOpcode::kFusion);
CHECK(parent() != nullptr);
auto new_instruction =
WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
@ -883,6 +888,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
old_fused_instruction->CloneWithNewOperands(
old_fused_instruction->shape(), new_operands));
HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
new_fused_instruction->set_parent(parent());
new_fused_instruction->parent_fusion_instruction_ = new_instruction.get();
InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
}
@ -893,6 +899,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
new_instruction->fused_instructions_ = std::move(new_fused_instructions);
new_instruction->fused_parameters_ = std::move(new_fused_parameters);
new_instruction->fused_root_ = FindOrDie(old_to_new, fused_root_);
new_instruction->set_parent(parent());
new_instruction->CheckFusionInstruction();
return new_instruction;
}


@ -538,6 +538,9 @@ class HloInstruction {
// instruction. The order is a reverse postorder of the fused expression (root
// is first in the order).
//
// Note: although the list itself is const, the instructions contained in the
// list returned here are mutable.
//
// Precondition: opcode() == HloOpcode::kFusion
const std::list<std::unique_ptr<HloInstruction>>& fused_instructions() const;
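
A small sketch of what this note permits; the loop shape matches HloComputation::Reparent elsewhere in this commit:

// The list is const, but the pointed-to instructions may be mutated.
for (auto& fused : fusion->fused_instructions()) {
  fused->set_parent(computation);
}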


@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
using ::tensorflow::strings::StrAppend;
namespace xla {
namespace {
@ -44,6 +46,14 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end());
auto run_invariant_checkers = [this, module]() -> Status {
for (auto& invariant_checker : invariant_checkers_) {
TF_ASSIGN_OR_RETURN(bool changed, invariant_checker->Run(module));
TF_RET_CHECK(!changed) << "invariant checkers must not change the graph";
}
return Status::OK();
};
string prefix = name().ToString() + ": pipeline start";
bool changed = false;
string message;
@ -55,15 +65,17 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
// Emit label containing: "after foo-pass, before bar-pass".
message.clear();
tensorflow::strings::StrAppend(&message, prefix, ", before ", pass->name());
StrAppend(&message, prefix, ", before ", pass->name());
DumpModule(dumper_, *module, message);
TF_RETURN_IF_ERROR(run_invariant_checkers());
TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
changed |= changed_this_pass;
prefix.clear();
tensorflow::strings::StrAppend(&prefix, name(), ": after ", pass->name());
StrAppend(&prefix, name(), ": after ", pass->name());
}
TF_RETURN_IF_ERROR(run_invariant_checkers());
DumpModule(dumper_, *module, prefix + ", pipeline end");
return changed;
}


@ -52,6 +52,16 @@ class HloPassPipeline : public HloPassInterface {
return *pass;
}
// Add an invariant-checking pass to the pipeline. It will be run before and
// after each HLO pass. The invariant checking pass must not mutate the graph
// (it is required to always return "false" from its Run() method).
template <typename T, typename... Args>
T& AddInvariantChecker(Args&&... args) {
auto pass = new T(std::forward<Args>(args)...);
invariant_checkers_.push_back(std::unique_ptr<T>(pass));
return *pass;
}
// Run all passes on the given HLO module.
StatusOr<bool> Run(HloModule* module) override;
@ -59,6 +69,7 @@ class HloPassPipeline : public HloPassInterface {
const string name_;
Compiler::HloDumper dumper_;
std::vector<std::unique_ptr<HloPassInterface>> passes_;
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
};
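
A hedged usage sketch of the new hook (SomeOptimizationPass is a hypothetical placeholder; HloVerifier is the checker added by this commit):

HloPassPipeline pipeline("example-pipeline", dump_hlo);
pipeline.AddInvariantChecker<HloVerifier>();  // runs before and after each pass
pipeline.AddPass<SomeOptimizationPass>();     // hypothetical pass
TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));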


@ -0,0 +1,38 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
namespace xla {
StatusOr<bool> HloVerifier::Run(HloModule* module) {
for (auto& computation : module->computations()) {
for (const auto& instruction : computation->instructions()) {
TF_RET_CHECK(instruction->parent() == computation.get());
if (instruction->opcode() == HloOpcode::kFusion) {
for (const auto& fused : instruction->fused_instructions()) {
TF_RET_CHECK(fused->parent() == computation.get())
<< "Fused HLO was missing a parent: " << fused->ToString()
<< " parent: " << fused->parent()
<< " computation: " << computation.get();
}
}
}
}
return false;
}
} // namespace xla
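
The verifier can also be run standalone; a minimal sketch:

HloVerifier verifier;
TF_ASSIGN_OR_RETURN(bool changed, verifier.Run(module.get()));
TF_RET_CHECK(!changed);  // the verifier never mutates the module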


@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
class HloVerifier : public HloPassInterface {
public:
~HloVerifier() override = default;
tensorflow::StringPiece name() const override { return "verifier"; }
// Note: always returns false (no instructions are ever modified by this
// pass).
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_


@ -29,7 +29,8 @@ limitations under the License.
namespace xla {
bool IsExpensive(const HloInstruction& instruction) {
/*static*/ bool InstructionFusion::IsExpensive(
const HloInstruction& instruction) {
switch (instruction.opcode()) {
// Cheap instructions.
case HloOpcode::kAbs:
@ -105,9 +106,14 @@ bool IsExpensive(const HloInstruction& instruction) {
return false;
}
bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer) {
return !(producer->users().size() == 1 && consumer->IsUserOf(producer));
namespace {
// Returns true if fusing producer into consumer would cause producer to be
// duplicated. This is the case if producer has uses other than consumer.
bool FusionWouldDuplicate(const HloInstruction& producer,
const HloInstruction& consumer) {
return !(producer.users().size() == 1 && consumer.IsUserOf(&producer));
}
} // namespace
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
bool changed = false;
@ -125,8 +131,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
std::vector<HloInstruction*> post_order(post_order_list.begin(),
post_order_list.end());
tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
for (std::vector<HloInstruction*>::size_type i = 0; i < post_order.size();
++i) {
for (size_t i = 0; i < post_order.size(); ++i) {
InsertOrDie(&post_order_index, post_order[i], i);
}
@ -263,8 +268,8 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
int64 operand_index) {
HloInstruction* producer = consumer->mutable_operand(operand_index);
// Cost condition: don't duplicate expensive instructions.
if (FusionWouldDuplicate(producer, consumer) &&
(IsExpensive(*producer) || !may_duplicate_)) {
if (FusionWouldDuplicate(*producer, *consumer) &&
(is_expensive_(*producer) || !may_duplicate_)) {
return false;
}
@ -277,7 +282,7 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
// Cost condition: do not fuse expensive producers into consumers that
// reuse operand elements.
if (consumer->ReusesOperandElements(operand_index) &&
IsExpensive(*producer)) {
is_expensive_(*producer)) {
return false;
}
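Taken together, the two cost conditions in ShouldFuse amount to a small predicate. Here is a minimal Python sketch of that logic; the Instr class and helper names are illustrative stand-ins, not the real XLA C++ API:

    class Instr:
        def __init__(self):
            self.users = []

    def fusion_would_duplicate(producer, consumer):
        # Fusing duplicates `producer` unless `consumer` is its only user.
        return not (len(producer.users) == 1 and consumer in producer.users)

    def should_fuse(producer, consumer, is_expensive, may_duplicate,
                    consumer_reuses_operand):
        # Cost condition 1: never duplicate an expensive producer, and never
        # duplicate at all when duplication is disallowed.
        if fusion_would_duplicate(producer, consumer) and (
                is_expensive(producer) or not may_duplicate):
            return False
        # Cost condition 2: don't fuse an expensive producer into a consumer
        # that reuses its operand elements.
        if consumer_reuses_operand and is_expensive(producer):
            return False
        return True

    p, c = Instr(), Instr()
    p.users = [c]  # sole user, so fusion does not duplicate p
    assert should_fuse(p, c, is_expensive=lambda _: True,
                       may_duplicate=False, consumer_reuses_operand=False)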

View File

@ -24,15 +24,6 @@ limitations under the License.
namespace xla {
// Returns true if the computation of the given instruction is significantly
// more expensive than just writing all the values of the instructions' result
// array. Expensive operations should not be duplicated.
bool IsExpensive(const HloInstruction& instruction);
// Returns true if fusing producer into consumer would cause producer to be
// duplicated. This is the case if producer has uses other than consumer.
bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in
@ -40,15 +31,22 @@ bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
// instructions to fuse.
class InstructionFusion : public HloPassInterface {
public:
explicit InstructionFusion(bool may_duplicate = true)
: may_duplicate_(may_duplicate) {}
~InstructionFusion() override {}
explicit InstructionFusion(
std::function<bool(const HloInstruction& instruction)> is_expensive,
bool may_duplicate = true)
: is_expensive_(is_expensive), may_duplicate_(may_duplicate) {}
~InstructionFusion() override = default;
tensorflow::StringPiece name() const override { return "fusion"; }
// Run instruction fusion on the given computation. Returns whether the
// computation was changed (instructions were fused).
StatusOr<bool> Run(HloModule* module) override;
// Returns true if the computation of the given instruction is significantly
// more expensive than just writing all the values of the instructions' result
// array. Expensive operations will not be duplicated.
static bool IsExpensive(const HloInstruction& instruction);
protected:
// Returns whether the given producer instruction should be fused into the
// given consumer instruction. producer is necessarily an operand of consumer.
@ -74,6 +72,10 @@ class InstructionFusion : public HloPassInterface {
private:
HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer);
// Used to determine if an HLO is expensive. Expensive operations will not be
// duplicated.
std::function<bool(const HloInstruction& instruction)> is_expensive_;
// Returns whether we may duplicate an instruction if we want to fuse it.
bool may_duplicate_;

View File

@ -36,7 +36,9 @@ TEST_F(InstructionFusionTest,
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(broadcast2, computation->root_instruction());
EXPECT_TRUE(
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
EXPECT_EQ(broadcast2, computation->root_instruction());
}
@ -55,7 +57,9 @@ TEST_F(InstructionFusionTest,
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(broadcast2, computation->root_instruction());
EXPECT_TRUE(
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
}
@ -73,7 +77,9 @@ TEST_F(InstructionFusionTest,
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape2, computation->root_instruction());
EXPECT_TRUE(
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
}
@ -91,7 +97,9 @@ TEST_F(InstructionFusionTest,
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose2, computation->root_instruction());
EXPECT_TRUE(
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
}
@ -106,7 +114,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape1, computation->root_instruction());
EXPECT_FALSE(
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
}
TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
@ -120,7 +130,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(reshape1, computation->root_instruction());
EXPECT_FALSE(
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
}
TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
@ -134,7 +146,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(transpose1, computation->root_instruction());
EXPECT_FALSE(
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
.Run(module.get())
.ValueOrDie());
}
} // namespace xla

View File

@ -106,6 +106,7 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
// at operand 0.
// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
bool CanShareOperandBufferWithUser(
HloInstruction* operand, const ShapeIndex& operand_index,
HloInstruction* user, const ShapeIndex& user_index,
@ -143,6 +144,11 @@ bool CanShareOperandBufferWithUser(
break;
}
return false;
} else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) {
// We eliminated other users in BufferLiveness::live_range_strictly_before,
// so here we just need to check that the use is at operand index 0.
std::vector<int64> operand_indices = user->OperandIndices(operand);
return operand_indices.size() == 1 && operand_indices[0] == 0;
}
// Check if 'user' is element-wise.
return user->IsElementwise();
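The new DynamicUpdateSlice operand-index check is simple enough to state outside C++. A hedged Python sketch, with toy objects standing in for HLO instructions:

    class Instr:
        def __init__(self, operands):
            self.operands = operands

    def can_share_with_dynamic_update_slice(user, operand):
        # Sharing is allowed only when `operand` is used exactly once, as the
        # in-place-updated buffer at operand index 0.
        indices = [i for i, op in enumerate(user.operands) if op is operand]
        return indices == [0]

    buf, update, start = object(), object(), object()
    dus = Instr([buf, update, start])
    assert can_share_with_dynamic_update_slice(dus, buf)
    assert not can_share_with_dynamic_update_slice(dus, update)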

View File

@ -256,8 +256,7 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
const Backend* backend, int device_ordinal) {
std::vector<const Allocation*> allocations;
for (tensorflow::gtl::ArraySlice<const GlobalDataHandle*>::size_type i = 0;
i < arguments.size(); ++i) {
for (size_t i = 0; i < arguments.size(); ++i) {
auto allocation_status = allocation_tracker_.Resolve(*arguments[i]);
if (!allocation_status.ok()) {
return Status(allocation_status.status().code(),
@ -296,8 +295,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
program_shape.parameters_size(), arguments.size());
}
for (tensorflow::gtl::ArraySlice<const Allocation*>::size_type i = 0;
i < arguments.size(); ++i) {
for (size_t i = 0; i < arguments.size(); ++i) {
// Verify that shape of arguments matches the shape of the arguments in the
// ProgramShape.
if (!ShapeUtil::Compatible(arguments[i]->shape(),
@ -385,8 +383,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
hlo_dumper, std::move(executors)));
if (!other_directory_path.empty()) {
for (std::vector<VersionedComputationHandle>::size_type i = 0;
i < versioned_handles.size(); ++i) {
for (size_t i = 0; i < versioned_handles.size(); ++i) {
executables[i]->set_session_module(std::move(session_modules[i]));
}
}

View File

@ -598,7 +598,9 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
// Build computation and add it to module as entry computation.
BuildModule(builder.Build());
// Run instruction fusion HloPass.
EXPECT_TRUE(InstructionFusion().Run(module_.get()).ValueOrDie());
EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive)
.Run(module_.get())
.ValueOrDie());
// Get computation root instruction (should be a kFusion).
auto* fusion = module_->entry_computation()->root_instruction();
EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());

View File

@ -56,6 +56,8 @@ class ClientLibraryTestBase : public ::testing::Test {
execution_options_.set_disable_fast_math(disabled);
}
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
// TODO(b/25566808): Add helper that populates a literal from a testdata file.
// Convenience methods for building and running a computation from a builder.

View File

@ -187,8 +187,13 @@ ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
}
ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
return ExecutableRunOptions().set_allocator(
GetOrCreateAllocator(local_client_->platform()));
ExecutableRunOptions run_options;
run_options.set_inter_op_thread_pool(
local_client_->backend().inter_op_thread_pool());
run_options.set_intra_op_thread_pool(
local_client_->backend().eigen_intra_op_thread_pool_device());
run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
return run_options;
}
std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocallyOrDie(

View File

@ -53,6 +53,7 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims) {
builder.ConstantR0<T>(a), builder.ConstantR0<T>(b),
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims));
SetSeed(42);
auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions()));
LiteralUtil::EachCell<T>(*actual,
@ -118,6 +119,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) {
builder.ConstantR0<int32>(range_size),
ShapeUtil::MakeShape(S32, {sample_size}));
SetSeed(42);
auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
std::vector<int32> counts(range_size, 0);
LiteralUtil::EachCell<int32>(
@ -264,6 +266,7 @@ XLA_TEST_F(PrngTest, TenValuesN01) {
builder.RngNormal(builder.ConstantR0<float>(0), builder.ConstantR0<float>(1),
ShapeUtil::MakeShape(F32, {10}));
SetSeed(42);
ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
// TODO(b/25995601): Test that resultant values are reasonable
}

View File

@ -91,6 +91,7 @@ cc_library(
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
"//tensorflow/contrib/layers:bucketization_op_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
],
)

View File

@ -37,7 +37,7 @@ tf_kernel_library(
srcs = [
"bigquery_reader_ops.cc",
],
visibility = ["//tensorflow:__subpackages__"],
visibility = ["//visibility:public"],
deps = [
":bigquery_table_accessor",
":bigquery_table_partition_proto_cc",

View File

@ -419,7 +419,7 @@ class DirichletMultinomialTest(test.TestCase):
with self.test_session() as sess:
dist = ds.DirichletMultinomial(
total_count=5.,
concentration=2. * self._rng.rand(4, 3, 2).astype(np.float32))
concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32))
n = int(3e3)
x = dist.sample(n, seed=0)
sample_mean = math_ops.reduce_mean(x, 0)
@ -448,7 +448,7 @@ class DirichletMultinomialTest(test.TestCase):
with self.test_session() as sess:
dist = ds.DirichletMultinomial(
total_count=5.,
concentration=2. * self._rng.rand(4).astype(np.float32))
concentration=1. + 2. * self._rng.rand(4).astype(np.float32))
n = int(5e3)
x = dist.sample(n, seed=0)
sample_mean = math_ops.reduce_mean(x, 0)

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from scipy import stats
from tensorflow.contrib import distributions
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
@ -50,6 +51,22 @@ class MultivariateNormalDiagTest(test.TestCase):
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
self.assertAllEqual([3, 1], dist.sample(3).get_shape())
def testDistWithBatchShapeOneThenTransformedThroughSoftplus(self):
# This complex combination of events resulted in a loss of static shape
# information when tensor_util.constant_value(self._needs_rotation) was
# being used incorrectly (resulting in always rotating).
# Batch shape = [1], event shape = [3]
mu = array_ops.zeros((1, 3))
diag = array_ops.ones((1, 3))
with self.test_session():
base_dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
dist = ds.TransformedDistribution(
base_dist,
validate_args=True,
bijector=bijectors.Softplus(event_ndims=1))
samps = dist.sample(5) # Shape [5, 1, 3].
self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape())
def testMean(self):
mu = [-1., 1]
diag = [1., -5]

View File

@ -546,7 +546,8 @@ class TransformedDistribution(distributions.Distribution):
def _maybe_rotate_dims(self, x, rotate_right=False):
"""Helper which rolls left event_dims left or right event_dims right."""
if tensor_util.constant_value(self._needs_rotation) is False:
needs_rotation_const = tensor_util.constant_value(self._needs_rotation)
if needs_rotation_const is not None and not needs_rotation_const:
return x
ndims = array_ops.rank(x)
n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims
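The rewritten check matters because tensor_util.constant_value returns a NumPy value when the constant is known (and None when it is not), and a NumPy false is not the Python singleton False. A quick illustration of why the old identity test never fired, in plain NumPy:

    import numpy as np

    needs_rotation = np.array(False)   # what constant_value may hand back
    print(needs_rotation is False)     # False: the old check never short-circuits
    print(needs_rotation is not None and not needs_rotation)  # True: fixed check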

View File

@ -35,6 +35,7 @@ See the @{$python/contrib.layers} guide.
@@relu6
@@repeat
@@safe_embedding_lookup_sparse
@@scale_gradient
@@separable_conv2d
@@separable_convolution2d
@@softmax
@ -68,6 +69,7 @@ See the @{$python/contrib.layers} guide.
@@embedding_column
@@scattered_embedding_column
@@input_from_feature_columns
@@transform_features
@@joint_weighted_sum_from_feature_columns
@@make_place_holder_tensors_for_base_features
@@multi_class_target

View File

@ -1248,21 +1248,29 @@ def scattered_embedding_column(column_name,
initializer=None):
"""Creates an embedding column of a sparse feature using parameter hashing.
The i-th embedding component of a value v is found by retrieving an
embedding weight whose index is a fingerprint of the pair (v,i).
This is a useful shorthand when you have a sparse feature you want to use an
embedding for, but also want to hash the embedding's values in each dimension
to a variable based on a different hash.
Specifically, the i-th embedding component of a value v is found by retrieving
an embedding weight whose index is a fingerprint of the pair (v,i).
An embedding column with sparse_column_with_hash_bucket such as
embedding_column(
embedding_column(
sparse_column_with_hash_bucket(column_name, bucket_size),
dimension)
could be replaced by
scattered_embedding_column(
column_name, size=bucket_size * dimension, dimension=dimension,
scattered_embedding_column(
column_name,
size=bucket_size * dimension,
dimension=dimension,
hash_key=tf.contrib.layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
for the same number of embedding parameters and hopefully reduced impact of
collisions with a cost of slowing down training.
for the same number of embedding parameters. This should hopefully reduce the
impact of collisions, but adds the cost of slowing down training.
Args:
column_name: A string defining sparse column name.

View File

@ -144,6 +144,7 @@ def _input_from_feature_columns(columns_to_tensors,
output_rank,
default_name):
"""Implementation of `input_from(_sequence)_feature_columns`."""
columns_to_tensors = columns_to_tensors.copy()
check_feature_columns(feature_columns)
with variable_scope.variable_scope(scope,
default_name=default_name,
@ -430,6 +431,7 @@ def joint_weighted_sum_from_feature_columns(columns_to_tensors,
ValueError: if FeatureColumn cannot be used for linear predictions.
"""
columns_to_tensors = columns_to_tensors.copy()
check_feature_columns(feature_columns)
with variable_scope.variable_scope(
scope,
@ -518,6 +520,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
Raises:
ValueError: if FeatureColumn cannot be used for linear predictions.
"""
columns_to_tensors = columns_to_tensors.copy()
check_feature_columns(feature_columns)
with variable_scope.variable_scope(
scope,
@ -684,8 +687,8 @@ def transform_features(features, feature_columns):
Returns:
A `dict` mapping FeatureColumn to `Tensor` and `SparseTensor` values.
"""
check_feature_columns(feature_columns)
columns_to_tensor = features.copy()
check_feature_columns(feature_columns)
transformer = _Transformer(columns_to_tensor)
for column in sorted(set(feature_columns), key=lambda x: x.key):
transformer.transform(column)
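Copying first is what preserves the contract that the caller's features mapping is never mutated, which the updated transformer test below asserts. A trivial sketch of the pattern, with toy string values:

    features = {"wire": "wire_tensor"}        # caller-owned dict
    working = features.copy()                 # transform into the copy only
    working["derived"] = "transformed_tensor"
    assert "derived" not in features          # caller's dict is untouched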

View File

@ -187,27 +187,28 @@ class TransformerTest(test.TestCase):
self.assertAllEqual(output.dense_shape.eval(), [2, 2])
def testEmbeddingColumn(self):
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
wire_tensor = sparse_tensor.SparseTensor(
values=["omar", "stringer", "marlo"],
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {"wire": wire_tensor}
output = feature_column_ops._Transformer(features).transform(
feature_column.embedding_column(hashed_sparse, 10))
expected = feature_column_ops._Transformer(features).transform(
hashed_sparse)
with self.test_session():
self.assertAllEqual(output.values.eval(), expected.values.eval())
self.assertAllEqual(output.indices.eval(), expected.indices.eval())
self.assertAllEqual(output.dense_shape.eval(),
expected.dense_shape.eval())
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
wire_embedding = feature_column.embedding_column(hashed_sparse, 10)
# Test transform features.
output = feature_column_ops.transform_features(
features=features, feature_columns=[hashed_sparse])
self.assertEqual(len(output), 1)
features=features, feature_columns=[hashed_sparse, wire_embedding])
# Check that the features dict hasn't changed
self.assertEqual({"wire": wire_tensor}, features)
self.assertEqual(len(output), 2)
self.assertIn(hashed_sparse, output)
self.assertIn(wire_embedding, output)
with self.test_session():
self.assertAllEqual(output[wire_embedding].indices.eval(),
wire_tensor.indices.eval())
self.assertAllEqual(output[wire_embedding].dense_shape.eval(), [2, 2])
self.assertAllEqual(output[wire_embedding].values.eval(),
output[hashed_sparse].values.eval())
def testSparseColumnWithKeys(self):
keys_sparse = feature_column.sparse_column_with_keys(

View File

@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.layers import convolutional as convolutional_layers
@ -68,6 +69,7 @@ __all__ = ['avg_pool2d',
'relu',
'relu6',
'repeat',
'scale_gradient',
'separable_conv2d',
'separable_convolution2d',
'softmax',
@ -1745,6 +1747,48 @@ def repeat(inputs, repetitions, layer, *args, **kwargs):
return outputs
def _scale_gradient_shape(op):
"""Shape helper function for scale_gradient function below."""
return [op.inputs[0].shape]
def _scale_gradient_grad(op, grad):
"""Python gradient helper function for scale_gradient function below."""
return [grad * op.inputs[1], None]
@function.Defun(python_grad_func=_scale_gradient_grad,
shape_func=_scale_gradient_shape)
def scale_gradient(inputs, gradient_multiplier):
"""Identity operation, but with the gradient multiplied by a tensor.
The TensorFlow gradient system will compute the gradient with respect to
`inputs` as the product of the gradient with respect to the `output`
multiplied by a specified `gradient_multiplier` tensor. If
`gradient_multiplier` is equal to 1, then this results in the true gradient.
Otherwise, it results in a scaled gradient.
This can be useful for adjusting the relative learning rate of different
parameter tensors when performing gradient descent; because this rescaling
can be inserted at arbitrary locations within a graph, it is often more
convenient to apply than simply rescaling the final computed gradients.
Args:
inputs: Tensor to be output.
gradient_multiplier: Tensor by which to multiply the gradient with respect
to `output` to compute the gradient with respect to `inputs`. Its shape
must be broadcastable to the shape of `inputs`.
Returns:
output Tensor, equal to `inputs`.
"""
# gradient_multiplier is implicitly saved by decorator, and only used for
# gradient computation.
del gradient_multiplier
return inputs
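A hedged usage sketch, assuming the 1.x contrib export added above (the variable names are illustrative; the test file below exercises the same behavior):

    import tensorflow as tf

    x = tf.constant([42.0])
    y = tf.contrib.layers.scale_gradient(x, tf.constant([2.0]))
    # Forward value: y == x. Backward: an incoming gradient of 3 arrives at x
    # as 3 * 2 = 6.
    g_x, = tf.gradients(y, [x], grad_ys=[tf.constant([3.0])])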
@add_arg_scope
def separable_convolution2d(
inputs,

View File

@ -2980,6 +2980,22 @@ class SeparableConv2dTest(test.TestCase):
sess.run(net, feed_dict={images_placeholder: images})
class ScaleGradientTests(test.TestCase):
"""Simple tests of the scale_gradient function."""
def testBasic(self):
with self.test_session():
x = np.array([42], np.float32)
gradient_scale = np.array([2], np.float32)
x = ops.convert_to_tensor(x)
y = layers_lib.scale_gradient(x, gradient_scale)
np.testing.assert_array_equal(x.eval(), y.eval())
g_x, = gradients_impl.gradients(y, [x], [np.array([3], np.float32)])
np.testing.assert_array_equal([3 * 2], g_x.eval())
class SoftmaxTests(test.TestCase):
def setUp(self):

View File

@ -165,8 +165,7 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
* embedding_lr_multipliers: Optional. A dictionary from
`EmbeddingColumn` to a `float` multiplier. Multiplier will be used to
multiply with learning rate for the embedding variables.
* input_layer_min_slice_size: Optional. The min slice size of input layer
partitions. If not provided, will use the default of 64M.
* input_layer_partitioner: Optional. Partitioner for input layer.
config: `RunConfig` object to configure the runtime settings.
Returns:
@ -174,7 +173,7 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
Raises:
ValueError: If both `linear_feature_columns` and `dnn_features_columns`
are empty at the same time.
are empty at the same time, or `input_layer_partitioner` is missing.
"""
head = params["head"]
linear_feature_columns = params.get("linear_feature_columns")
@ -186,9 +185,11 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
dnn_activation_fn = params.get("dnn_activation_fn") or nn.relu
dnn_dropout = params.get("dnn_dropout")
gradient_clip_norm = params.get("gradient_clip_norm")
input_layer_min_slice_size = (
params.get("input_layer_min_slice_size") or 64 << 20)
num_ps_replicas = config.num_ps_replicas if config else 0
input_layer_partitioner = params.get("input_layer_partitioner") or (
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
embedding_lr_multipliers = params.get("embedding_lr_multipliers", {})
fix_global_step_increment_bug = params.get(
"fix_global_step_increment_bug", True)
@ -221,10 +222,6 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
dnn_parent_scope,
values=tuple(six.itervalues(features)),
partitioner=dnn_partitioner):
input_layer_partitioner = (
partitioned_variables.min_max_variable_partitioner(
max_partitions=num_ps_replicas,
min_slice_size=input_layer_min_slice_size))
with variable_scope.variable_scope(
"input_from_feature_columns",
values=tuple(six.itervalues(features)),
@ -387,7 +384,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
config=None,
feature_engineering_fn=None,
embedding_lr_multipliers=None,
fix_global_step_increment_bug=False):
fix_global_step_increment_bug=False,
input_layer_partitioner=None):
"""Initializes a DNNLinearCombinedEstimator instance.
Note: New users must set `fix_global_step_increment_bug=True` when creating
@ -432,6 +430,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
steps to optimize both linear and dnn parts. If `True`, this bug is
fixed. New users must set this to `True`, but the default value is
`False` for backwards compatibility.
input_layer_partitioner: Optional. Partitioner for input layer.
Raises:
ValueError: If both linear_feature_columns and dnn_features_columns are
@ -459,6 +458,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
"gradient_clip_norm": gradient_clip_norm,
"embedding_lr_multipliers": embedding_lr_multipliers,
"fix_global_step_increment_bug": fix_global_step_increment_bug,
"input_layer_partitioner": input_layer_partitioner
},
feature_engineering_fn=feature_engineering_fn)
@ -602,19 +602,25 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
ValueError: If both `linear_feature_columns` and `dnn_features_columns`
are empty at the same time.
"""
if n_classes < 2:
raise ValueError("n_classes should be greater than 1. Given: {}".format(
n_classes))
head = head_lib.multi_class_head(
n_classes=n_classes,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias)
linear_feature_columns = tuple(linear_feature_columns or [])
dnn_feature_columns = tuple(dnn_feature_columns or [])
self._feature_columns = linear_feature_columns + dnn_feature_columns
if not self._feature_columns:
raise ValueError("Either linear_feature_columns or dnn_feature_columns "
"must be defined.")
head = head_lib.multi_class_head(
n_classes=n_classes,
weight_column_name=weight_column_name,
enable_centered_bias=enable_centered_bias)
# TODO(b/35922130): Replace with `input_layer_partitioner` arg.
input_layer_partitioner = None
if input_layer_min_slice_size is not None:
input_layer_partitioner = (
partitioned_variables.min_max_variable_partitioner(
max_partitions=config.num_ps_replicas if config else 0,
min_slice_size=input_layer_min_slice_size))
super(DNNLinearCombinedClassifier, self).__init__(
model_fn=_dnn_linear_combined_model_fn,
model_dir=model_dir,
@ -631,7 +637,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
"dnn_dropout": dnn_dropout,
"gradient_clip_norm": gradient_clip_norm,
"embedding_lr_multipliers": embedding_lr_multipliers,
"input_layer_min_slice_size": input_layer_min_slice_size,
"input_layer_partitioner": input_layer_partitioner,
"fix_global_step_increment_bug": fix_global_step_increment_bug,
},
feature_engineering_fn=feature_engineering_fn)
@ -916,6 +922,15 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
if not self._feature_columns:
raise ValueError("Either linear_feature_columns or dnn_feature_columns "
"must be defined.")
# TODO(b/35922130): Replace with `input_layer_partitioner` arg.
input_layer_partitioner = None
if input_layer_min_slice_size is not None:
input_layer_partitioner = (
partitioned_variables.min_max_variable_partitioner(
max_partitions=config.num_ps_replicas if config else 0,
min_slice_size=input_layer_min_slice_size))
head = head_lib.regression_head(
weight_column_name=weight_column_name,
label_dimension=label_dimension,
@ -936,7 +951,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
"dnn_dropout": dnn_dropout,
"gradient_clip_norm": gradient_clip_norm,
"embedding_lr_multipliers": embedding_lr_multipliers,
"input_layer_min_slice_size": input_layer_min_slice_size,
"input_layer_partitioner": input_layer_partitioner,
"fix_global_step_increment_bug": fix_global_step_increment_bug,
},
feature_engineering_fn=feature_engineering_fn)
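The backward-compatibility shim repeated in both subclasses boils down to one helper. A hedged sketch (the helper name is hypothetical; the import matches the module used above):

    from tensorflow.python.ops import partitioned_variables

    def _partitioner_from_min_slice_size(input_layer_min_slice_size, config):
        # An explicit min slice size becomes an equivalent partitioner;
        # None lets the model_fn fall back to its own 64 << 20 (64MB) default.
        if input_layer_min_slice_size is None:
            return None
        return partitioned_variables.min_max_variable_partitioner(
            max_partitions=config.num_ps_replicas if config else 0,
            min_slice_size=input_layer_min_slice_size)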

View File

@ -341,7 +341,7 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
input_layer_min_slice_size=1)
# Ensure the param is passed in.
self.assertEqual(1, classifier.params['input_layer_min_slice_size'])
self.assertTrue(callable(classifier.params['input_layer_partitioner']))
# Ensure the partition count is 10.
classifier.fit(input_fn=_input_fn_float_label, steps=50)

View File

@ -188,6 +188,10 @@ def build_sequence_input(features,
A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, ?]`.
This will be used as input to an RNN.
"""
features = features.copy()
features.update(layers.transform_features(
features,
list(sequence_feature_columns) + list(context_feature_columns or [])))
sequence_input = layers.sequence_input_from_feature_columns(
columns_to_tensors=features,
feature_columns=sequence_feature_columns,

View File

@ -378,6 +378,8 @@ class BaseEstimator(
self._model_dir = tempfile.mkdtemp()
logging.warning('Using temporary folder as model directory: %s',
self._model_dir)
if self._config.model_dir is None:
self._config = self._config.replace(model_dir=self._model_dir)
# Set device function depending if there are replicas or not.
self._device_fn = _get_replica_device_setter(self._config)

View File

@ -481,6 +481,7 @@ class EstimatorTest(test.TestCase):
est = estimator.Estimator(model_fn=linear_model_fn,
config=config)
self.assertEqual('test_dir', est.config.model_dir)
self.assertEqual('test_dir', est.model_dir)
def testModelDirAndRunConfigModelDir(self):
config = run_config.RunConfig(model_dir='test_dir')
@ -489,11 +490,30 @@ class EstimatorTest(test.TestCase):
model_dir='test_dir')
self.assertEqual('test_dir', est.config.model_dir)
with self.assertRaises(ValueError):
with self.assertRaisesRegexp(
ValueError,
'model_dir are set both in constructor and RunConfig, '
'but with different'):
estimator.Estimator(model_fn=linear_model_fn,
config=config,
model_dir='different_dir')
def testModelDirIsCopiedToRunConfig(self):
config = run_config.RunConfig()
self.assertIsNone(config.model_dir)
est = estimator.Estimator(model_fn=linear_model_fn,
model_dir='test_dir',
config=config)
self.assertEqual('test_dir', est.config.model_dir)
self.assertEqual('test_dir', est.model_dir)
def testModelDirAsTempDir(self):
with test.mock.patch.object(tempfile, 'mkdtemp', return_value='temp_dir'):
est = estimator.Estimator(model_fn=linear_model_fn)
self.assertEqual('temp_dir', est.config.model_dir)
self.assertEqual('temp_dir', est.model_dir)
def testCheckInputs(self):
est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn))
# Lambdas so we have two different objects to compare

View File

@ -231,6 +231,8 @@ def sdca_model_fn(features, labels, mode, params):
with variable_scope.variable_op_scope(
features.values(), parent_scope) as scope:
features = features.copy()
features.update(layers.transform_features(features, feature_columns))
logits, columns_to_variables, bias = (
layers.weighted_sum_from_feature_columns(
columns_to_tensors=features,

View File

@ -18,12 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import os
import six
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.training import server_lib
@ -291,7 +293,7 @@ class RunConfig(ClusterConfig):
new_copy = copy.deepcopy(self)
# TODO(xiejw): Allow more fields, such as the user allowed changed ones.
# TODO(b/33295821): Allow more fields to be replaced.
for key, new_value in six.iteritems(kwargs):
if key == 'model_dir':
new_copy._model_dir = new_value # pylint: disable=protected-access
@ -301,6 +303,24 @@ class RunConfig(ClusterConfig):
return new_copy
@experimental
def uid(self):
"""Generates a 'Unique Identifier' based on all internal fields.
Callers can use the uid string to check `RunConfig` instance integrity
within a single session, but should not rely on the implementation details,
which are subject to change.
Returns:
A uid string.
"""
# TODO(b/33295821): Allows user to specify a whitelist.
state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')}
ordered_state = collections.OrderedDict(
sorted(state.items(), key=lambda t: t[0]))
return ', '.join(
'%s=%r' % (k, v) for (k, v) in six.iteritems(ordered_state))
@property
def model_dir(self):
return self._model_dir
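A toy illustration of why the uid is deterministic and field-order independent; a plain dict stands in for the config's internal state:

    state = {"_tf_random_seed": 42, "_model_dir": "/tmp/m"}
    uid = ", ".join("%s=%r" % kv for kv in sorted(state.items()))
    # "_model_dir='/tmp/m', _tf_random_seed=42": the same fields always give
    # the same uid, and replacing _model_dir yields a different string.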

View File

@ -238,6 +238,18 @@ class RunConfigTest(test.TestCase):
with self.assertRaises(ValueError):
config.replace(some_undefined_property=RANDOM_SEED)
def test_uid(self):
config = run_config.RunConfig(
tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
expected_uid = config.uid()
# Check 10 times; the uid should be deterministic for unchanged fields.
for _ in range(10):
self.assertEqual(expected_uid, config.uid())
new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
self.assertNotEqual(expected_uid, new_config.uid())
if __name__ == "__main__":
test.main()

View File

@ -42,9 +42,7 @@ from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import session_manager as session_manager_lib
@ -121,205 +119,6 @@ def _run_with_monitors(session, step, tensors, feed_dict, monitors):
return outputs, should_stop
def _monitored_train(graph,
output_dir,
train_op,
loss_op,
global_step_tensor=None,
init_op=None,
init_feed_dict=None,
init_fn=None,
log_every_steps=10,
supervisor_is_chief=True,
supervisor_master='',
supervisor_save_model_secs=600,
supervisor_save_model_steps=None,
keep_checkpoint_max=5,
keep_checkpoint_every_n_hours=10000.0,
supervisor_save_summaries_secs=None,
supervisor_save_summaries_steps=100,
feed_fn=None,
steps=None,
fail_on_nan_loss=True,
hooks=None,
max_steps=None):
"""Train a model via monitored_session.
Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
run a training loop. The given `train_op` performs one step of training on the
model and is expected to increment the `global_step_tensor`, a scalar integer
tensor counting training steps. The `loss_op` represents the objective function
of the training. This function uses `Supervisor` to initialize the
graph (from a checkpoint if one is available in `output_dir`), write summaries
defined in the graph, and write regular checkpoints as defined by
`supervisor_save_model_secs`.
Training continues until `global_step_tensor` evaluates to `max_steps`, or, if
`fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the
program is terminated with exit code 1.
Args:
graph: A graph to train. It is expected that this graph is not in use
elsewhere.
output_dir: A directory to write outputs to.
train_op: An op that performs one training step when run.
loss_op: A scalar loss tensor.
global_step_tensor: A tensor representing the global step. If none is given,
one is extracted from the graph using the same logic as in `Supervisor`.
init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
default.
init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
This feed dictionary will be used when `init_op` is evaluated.
init_fn: Optional callable passed to Supervisor to initialize the model.
log_every_steps: Output logs regularly. The logs contain timing data and the
current loss. A `0` or negative value disables logging.
supervisor_is_chief: Whether the current process is the chief supervisor in
charge of restoring the model and running standard services.
supervisor_master: The master string to use when preparing the session.
supervisor_save_model_secs: Save checkpoints every this many seconds. Can
not be specified with `supervisor_save_model_steps`.
supervisor_save_model_steps: Save checkpoints every this many steps. Can not
be specified with `supervisor_save_model_secs`.
keep_checkpoint_max: The maximum number of recent checkpoint files to
keep. As new files are created, older files are deleted. If None or 0,
all checkpoint files are kept. This is simply passed as the max_to_keep
arg to `tf.train.Saver` constructor.
keep_checkpoint_every_n_hours: In addition to keeping the most recent
`keep_checkpoint_max` checkpoint files, you might want to keep one checkpoint file
for every N hours of training. This can be useful if you want to later
analyze how a model progressed during a long training session. For
example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
one checkpoint file for every 2 hours of training. The default value of
10,000 hours effectively disables the feature.
supervisor_save_summaries_secs: Save summaries every
`supervisor_save_summaries_secs` seconds when training.
supervisor_save_summaries_steps: Save summaries every
`supervisor_save_summaries_steps` steps when training. Exactly one of
`supervisor_save_model_steps` and `supervisor_save_model_secs` should be
specified, and the other should be None.
feed_fn: A function that is called every iteration to produce a `feed_dict`
passed to `session.run` calls. Optional.
steps: Trains for this many steps (e.g. current global step + `steps`).
fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
evaluates to `NaN`. If false, continue training as if nothing happened.
hooks: List of `SessionRunHook` subclass instances. Used for callbacks
inside the training loop.
max_steps: Number of total steps for which to train model. If `None`,
train forever. Two calls of fit(steps=100) mean 200 training iterations. On
the other hand, two calls of fit(max_steps=100) mean the second call performs
no iterations, since the first call already did all 100 steps.
Returns:
The final loss value.
Raises:
ValueError: If `output_dir`, `train_op`, `loss_op`, or `global_step_tensor`
is not provided. See `tf.contrib.framework.get_global_step` for how we
look up the latter if not provided explicitly.
NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
evaluates to `NaN`.
ValueError: If both `steps` and `max_steps` are not `None`.
"""
if (steps is not None) and (max_steps is not None):
raise ValueError('Can not provide both steps and max_steps.')
if not output_dir:
raise ValueError('Output directory should be non-empty %s.' % output_dir)
if train_op is None:
raise ValueError('Missing train_op.')
if loss_op is None:
raise ValueError('Missing loss_op.')
if hooks is None:
hooks = []
if not isinstance(hooks, list):
raise ValueError('Hooks should be a list.')
with graph.as_default():
global_step_tensor = contrib_variables.assert_or_get_global_step(
graph, global_step_tensor)
if global_step_tensor is None:
raise ValueError('No "global_step" was provided or found in the graph.')
if max_steps is not None:
try:
start_step = load_variable(output_dir, global_step_tensor.name)
if max_steps <= start_step:
logging.info('Skipping training since max_steps has already been saved.')
return None
except: # pylint: disable=bare-except
pass
# Adapted SessionRunHooks such as ExportMonitor depend on the
# CheckpointSaverHook to be executed before they should be executed.
# The `hooks` param comprises of deprecated monitor hooks
# (such as ExportMonitor). Appending them after the basic_session_run_hooks.
all_hooks = []
with graph.as_default():
all_hooks.append(basic_session_run_hooks.NanTensorHook(
loss_op, fail_on_nan_loss=fail_on_nan_loss))
if log_every_steps > 0:
all_hooks.append(basic_session_run_hooks.LoggingTensorHook({
'loss': loss_op.name,
'step': global_step_tensor.name
}, every_n_iter=log_every_steps))
def make_saver():
return tf_saver.Saver(
sharded=True,
max_to_keep=keep_checkpoint_max,
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
defer_build=True)
scaffold = monitored_session.Scaffold(
init_op=init_op,
init_feed_dict=init_feed_dict,
init_fn=init_fn,
saver=monitored_session.Scaffold.get_or_default('saver',
ops.GraphKeys.SAVERS,
make_saver))
if not supervisor_is_chief:
session_creator = monitored_session.WorkerSessionCreator(
scaffold=scaffold,
master=supervisor_master)
else:
session_creator = monitored_session.ChiefSessionCreator(
scaffold=scaffold,
checkpoint_dir=output_dir,
master=supervisor_master)
summary_writer = summary_io.SummaryWriterCache.get(output_dir)
all_hooks.append(
basic_session_run_hooks.StepCounterHook(
summary_writer=summary_writer))
all_hooks.append(
basic_session_run_hooks.SummarySaverHook(
save_secs=supervisor_save_summaries_secs,
save_steps=supervisor_save_summaries_steps,
summary_writer=summary_writer,
scaffold=scaffold))
if (supervisor_save_model_secs is not None
or supervisor_save_model_steps is not None):
all_hooks.append(
basic_session_run_hooks.CheckpointSaverHook(
output_dir,
save_secs=supervisor_save_model_secs,
save_steps=supervisor_save_model_steps,
scaffold=scaffold))
if steps is not None or max_steps is not None:
all_hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
all_hooks.extend(hooks)
with monitored_session.MonitoredSession(
session_creator=session_creator,
hooks=all_hooks) as super_sess:
loss = None
while not super_sess.should_stop():
_, loss = super_sess.run([train_op, loss_op], feed_fn() if feed_fn else
None)
summary_io.SummaryWriterCache.clear()
return loss
@_graph_action_deprecation
def train(graph,
output_dir,

View File

@ -27,7 +27,6 @@ from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.learn.python import learn
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.ops import control_flow_ops
@ -36,7 +35,6 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
@ -220,7 +218,7 @@ class GraphActionsTest(test.TestCase):
self.assertTrue(request_stop.called)
def test_run_feeds_iter_calls_resources_init(self):
with ops.Graph().as_default() as g:
with ops.Graph().as_default():
in0, _, _ = self._build_inference_graph()
handle = test_ops.stub_resource_handle_op(container='a', shared_name='b')
resources.register_resource(
@ -314,7 +312,7 @@ class GraphActionsTest(test.TestCase):
with ops.Graph().as_default() as g, self.test_session(g):
variables_lib.create_global_step()
v = variables.Variable(1.0)
w = variables.Variable(
variables.Variable(
v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
ready_for_local_init_op = variables.report_uninitialized_variables(
variables.global_variables())
@ -396,223 +394,11 @@ class GraphActionsTest(test.TestCase):
}},
expected_session_logs=[])
def test_train_invalid_args(self):
with ops.Graph().as_default() as g, self.test_session(g):
train_op = constant_op.constant(1.0)
loss_op = constant_op.constant(2.0)
with self.assertRaisesRegexp(ValueError, 'utput directory'):
learn.graph_actions._monitored_train(
g, # pylint: disable=protected-access
output_dir=None,
train_op=train_op,
loss_op=loss_op)
with self.assertRaisesRegexp(ValueError, 'utput directory'):
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir='',
train_op=constant_op.constant(1.0),
loss_op=constant_op.constant(2.0))
with self.assertRaisesRegexp(ValueError, 'train_op'):
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=None,
loss_op=loss_op)
with self.assertRaisesRegexp(ValueError, 'loss_op'):
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=constant_op.constant(1.0),
loss_op=None)
with self.assertRaisesRegexp(ValueError, 'global_step'):
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=constant_op.constant(1.0),
loss_op=loss_op)
# TODO(ptucker): Resume training from previous ckpt.
# TODO(ptucker): !supervisor_is_chief
# TODO(ptucker): Custom init op for training.
# TODO(ptucker): Mock supervisor, and assert all interactions.
def test_train(self):
with ops.Graph().as_default() as g, self.test_session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer)
self._assert_ckpt(self._output_dir, False)
loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=constant_op.constant(2.0),
steps=1)
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=g.as_graph_def(add_shapes=True),
saver_def=monitored_session.Scaffold().finalize().saver.saver_def)
self.assertEqual(2.0, loss)
self._assert_summaries(
self._output_dir,
writer,
expected_graphs=[g],
expected_meta_graphs=[meta_graph_def])
self._assert_ckpt(self._output_dir, True)
def test_train_steps_is_incremental(self):
with ops.Graph().as_default() as g, self.test_session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=constant_op.constant(2.0),
steps=10)
step = checkpoint_utils.load_variable(
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
with ops.Graph().as_default() as g, self.test_session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=constant_op.constant(2.0),
steps=15)
step = checkpoint_utils.load_variable(
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(25, step)
def test_train_max_steps_is_not_incremental(self):
with ops.Graph().as_default() as g, self.test_session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=constant_op.constant(2.0),
max_steps=10)
step = checkpoint_utils.load_variable(
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
with ops.Graph().as_default() as g, self.test_session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=constant_op.constant(2.0),
max_steps=15)
step = checkpoint_utils.load_variable(
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(15, step)
def test_train_skip_train_if_max_step_already_saved(self):
with ops.Graph().as_default() as g, self.test_session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=constant_op.constant(2.0),
max_steps=10)
step = checkpoint_utils.load_variable(
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
with ops.Graph().as_default() as g, self.test_session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=constant_op.constant(2.0),
max_steps=10)
step = checkpoint_utils.load_variable(
self._output_dir, variables_lib.get_global_step().name)
self.assertEqual(10, step)
def test_train_loss(self):
with ops.Graph().as_default() as g, self.test_session(g):
variables_lib.create_global_step()
loss_var = variables_lib.local_variable(10.0)
train_op = control_flow_ops.group(
state_ops.assign_add(variables_lib.get_global_step(), 1),
state_ops.assign_add(loss_var, -1.0))
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer)
self._assert_ckpt(self._output_dir, False)
loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=loss_var.value(),
steps=6)
self.assertEqual(4.0, loss)
self._assert_summaries(
self._output_dir,
writer,
expected_graphs=[g],
expected_meta_graphs=None)
self._assert_ckpt(self._output_dir, True)
def test_train_summaries(self):
with ops.Graph().as_default() as g, self.test_session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
loss_op = constant_op.constant(2.0)
summary.scalar('loss', loss_op)
writer = learn.graph_actions.get_summary_writer(self._output_dir)
self._assert_summaries(self._output_dir, writer)
self._assert_ckpt(self._output_dir, False)
loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=loss_op,
steps=1)
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=g.as_graph_def(add_shapes=True),
saver_def=monitored_session.Scaffold().finalize().saver.saver_def)
self.assertEqual(2.0, loss)
self._assert_summaries(
self._output_dir,
writer,
expected_graphs=[g],
expected_meta_graphs=[meta_graph_def],
expected_summaries={1: {
'loss': 2.0
}})
self._assert_ckpt(self._output_dir, True)
def test_train_override_saver(self):
with ops.Graph().as_default() as g, self.test_session(g):
with ops.control_dependencies(self._build_inference_graph()):
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
self._assert_ckpt(self._output_dir, False)
real_saver = saver_lib.Saver()
saver = test.mock.Mock(wraps=real_saver, saver_def=real_saver.saver_def)
ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access
g,
output_dir=self._output_dir,
train_op=train_op,
loss_op=constant_op.constant(2.0),
steps=1)
self.assertEqual(2.0, loss)
self._assert_ckpt(self._output_dir, True)
self.assertTrue(saver.build.called)
self.assertEqual(1, saver.save.call_count)
# TODO(ispir): remove following tests after deprecated train.
class GraphActionsTrainTest(test.TestCase):

View File

@ -60,14 +60,6 @@ class LinearOperatorTriLTest(
return operator, mat, feed_dict
def test_assert_positive_definite(self):
# Singular matrix with one positive eigenvalue and one negative eigenvalue.
with self.test_session():
tril = [[1., 0.], [1., -1.]]
operator = linalg.LinearOperatorTriL(tril)
with self.assertRaisesOpError("was not positive definite"):
operator.assert_positive_definite().run()
def test_assert_non_singular(self):
# Singular matrix with one positive eigenvalue and one zero eigenvalue.
with self.test_session():

View File

@ -21,8 +21,10 @@ import numpy as np
from tensorflow.contrib import linalg as linalg_lib
from tensorflow.contrib.linalg.python.ops import linear_operator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -91,6 +93,142 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
z, message="ABC123").run()
class BroadcastMatrixBatchDimsTest(test.TestCase):
def test_zero_batch_matrices_returned_as_empty_list(self):
self.assertAllEqual(
[], linear_operator_util.broadcast_matrix_batch_dims([]))
def test_one_batch_matrix_returned_after_tensor_conversion(self):
arr = rng.rand(2, 3, 4)
tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
self.assertTrue(isinstance(tensor, ops.Tensor))
with self.test_session():
self.assertAllClose(arr, tensor.eval())
def test_static_dims_broadcast(self):
# x.batch_shape = [3, 1, 2]
# y.batch_shape = [4, 1]
# broadcast batch shape = [3, 4, 2]
x = rng.rand(3, 1, 2, 1, 5)
y = rng.rand(4, 1, 3, 7)
batch_of_zeros = np.zeros((3, 4, 2, 1, 1))
x_bc_expected = x + batch_of_zeros
y_bc_expected = y + batch_of_zeros
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
with self.test_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = sess.run([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
def test_static_dims_broadcast_second_arg_higher_rank(self):
# x.batch_shape = [1, 2]
# y.batch_shape = [1, 3, 1]
# broadcast batch shape = [1, 3, 2]
x = rng.rand(1, 2, 1, 5)
y = rng.rand(1, 3, 2, 3, 7)
batch_of_zeros = np.zeros((1, 3, 2, 1, 1))
x_bc_expected = x + batch_of_zeros
y_bc_expected = y + batch_of_zeros
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
with self.test_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = sess.run([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
def test_dynamic_dims_broadcast_32bit(self):
# x.batch_shape = [3, 1, 2]
# y.batch_shape = [4, 1]
# broadcast batch shape = [3, 4, 2]
x = rng.rand(3, 1, 2, 1, 5).astype(np.float32)
y = rng.rand(4, 1, 3, 7).astype(np.float32)
batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
x_bc_expected = x + batch_of_zeros
y_bc_expected = y + batch_of_zeros
x_ph = array_ops.placeholder(dtypes.float32)
y_ph = array_ops.placeholder(dtypes.float32)
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
with self.test_session() as sess:
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self):
# x.batch_shape = [1, 2]
# y.batch_shape = [3, 4, 1]
# broadcast batch shape = [3, 4, 2]
x = rng.rand(1, 2, 1, 5).astype(np.float32)
y = rng.rand(3, 4, 1, 3, 7).astype(np.float32)
batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
x_bc_expected = x + batch_of_zeros
y_bc_expected = y + batch_of_zeros
x_ph = array_ops.placeholder(dtypes.float32)
y_ph = array_ops.placeholder(dtypes.float32)
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
with self.test_session() as sess:
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
def test_less_than_two_dims_raises_static(self):
x = rng.rand(3)
y = rng.rand(1, 1)
with self.assertRaisesRegexp(ValueError, "at least two dimensions"):
linear_operator_util.broadcast_matrix_batch_dims([x, y])
with self.assertRaisesRegexp(ValueError, "at least two dimensions"):
linear_operator_util.broadcast_matrix_batch_dims([y, x])
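What these tests pin down is that only the leading (batch) dimensions broadcast, while the trailing two matrix dimensions are left untouched. In plain NumPy terms (np.broadcast_shapes needs NumPy >= 1.20; this mirrors the zero-padding trick the tests use to build expected values):

    import numpy as np

    x = np.random.rand(3, 1, 2, 1, 5)   # batch shape [3, 1, 2], 1x5 matrices
    y = np.random.rand(4, 1, 3, 7)      # batch shape [4, 1],    3x7 matrices
    batch = np.broadcast_shapes((3, 1, 2), (4, 1))   # -> (3, 4, 2)
    x_bc = x + np.zeros(batch + (1, 1))  # broadcast batch dims only
    y_bc = y + np.zeros(batch + (1, 1))
    assert x_bc.shape == (3, 4, 2, 1, 5) and y_bc.shape == (3, 4, 2, 3, 7)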
class MatmulWithBroadcastTest(test.TestCase):
def test_static_dims_broadcast(self):
# batch_shape = [2]
# for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7
x = rng.rand(2, 1, 3)
y = rng.rand(3, 7)
y_broadcast = y + np.zeros((2, 1, 1))
with self.test_session():
result = linear_operator_util.matmul_with_broadcast(x, y)
self.assertAllEqual((2, 1, 7), result.get_shape())
expected = math_ops.matmul(x, y_broadcast)
self.assertAllEqual(expected.eval(), result.eval())
def test_dynamic_dims_broadcast_32bit(self):
# batch_shape = [2]
# for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7
x = rng.rand(2, 1, 3)
y = rng.rand(3, 7)
y_broadcast = y + np.zeros((2, 1, 1))
x_ph = array_ops.placeholder(dtypes.float64)
y_ph = array_ops.placeholder(dtypes.float64)
with self.test_session() as sess:
result, expected = sess.run(
[linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
math_ops.matmul(x, y_broadcast)],
feed_dict={x_ph: x, y_ph: y})
self.assertAllEqual(expected, result)
class DomainDimensionStubOperator(object):
def __init__(self, domain_dimension):

View File

@ -155,8 +155,9 @@ class LinearOperator(object):
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose. If `dtype` is real, this is equivalent to being symmetric.
is_positive_definite: Expect that this operator is positive definite,
meaning the real part of all eigenvalues is positive. We do not require
the operator to be self-adjoint to be positive-definite. See:
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices
is_square: Expect that this operator acts like square [batch] matrices.
@@ -461,8 +462,9 @@ class LinearOperator(object):
def assert_positive_definite(self, name="assert_positive_definite"):
"""Returns an `Op` that asserts this operator is positive definite.
Here, positive definite means the real part of all eigenvalues is positive.
We do not require the operator to be self-adjoint.
Here, positive definite means that the quadratic form `x^H A x` has positive
real part for all nonzero `x`. Note that we do not require the operator to
be self-adjoint to be positive definite.
Args:
name: A name to give this `Op`.
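For concreteness, here is a small NumPy check of the definition above (an illustrative aside, not part of the change): the matrix below is not self-adjoint, yet its quadratic form is positive for every nonzero real `x`, because the antisymmetric part contributes nothing to `x^H A x`.

```python
import numpy as np

a = np.array([[1., 1.],
              [-1., 1.]])  # a != a.T; eigenvalues are 1 +/- 1j
x = np.random.randn(1000, 2)
# x^T a x reduces to x0**2 + x1**2 here: the cross terms cancel.
quad = np.einsum("bi,ij,bj->b", x, a, x)
assert np.all(quad > 0.)
```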

View File

@@ -113,7 +113,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
is_self_adjoint=None,
is_positive_definite=None,
name=None):
"""Initialize a `LinearOperatorComposition`.
r"""Initialize a `LinearOperatorComposition`.
`LinearOperatorComposition` is initialized with a list of operators
`[op_1,...,op_J]`. For the `apply` method to be well defined, the
@@ -127,9 +127,10 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose.
is_positive_definite: Expect that this operator is positive definite,
meaning the real part of all eigenvalues is positive. We do not require
the operator to be self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices
name: A name for this `LinearOperator`. Default is the individual
operators names joined with `_o_`.

View File

@@ -114,7 +114,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
is_self_adjoint=None,
is_positive_definite=None,
name="LinearOperatorDiag"):
"""Initialize a `LinearOperatorDiag`.
r"""Initialize a `LinearOperatorDiag`.
Args:
diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0`, `N >= 0`.
@@ -124,9 +124,10 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose. If `diag.dtype` is real, this is auto-set to `True`.
is_positive_definite: Expect that this operator is positive definite,
meaning the real part of all eigenvalues is positive. We do not require
the operator to be self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices
name: A name for this `LinearOperator`.

View File

@@ -109,7 +109,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
is_self_adjoint=None,
is_positive_definite=None,
name="LinearOperatorFullMatrix"):
"""Initialize a `LinearOperatorFullMatrix`.
r"""Initialize a `LinearOperatorFullMatrix`.
Args:
matrix: Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`.
@@ -118,9 +118,10 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose.
is_positive_definite: Expect that this operator is positive definite,
meaning the real part of all eigenvalues is positive. We do not require
the operator to be self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices
name: A name for this `LinearOperator`.

View File

@@ -200,7 +200,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
is_positive_definite=True,
assert_proper_shapes=False,
name="LinearOperatorIdentity"):
"""Initialize a `LinearOperatorIdentity`.
r"""Initialize a `LinearOperatorIdentity`.
The `LinearOperatorIdentity` is initialized with arguments defining `dtype`
and shape.
@@ -218,7 +218,12 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose.
is_positive_definite: Expect that this operator is positive definite.
is_positive_definite: Expect that this operator is positive definite,
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices
assert_proper_shapes: Python `bool`. If `False`, only perform static
checks that initialization and method arguments have proper shape.
If `True`, and static checks are inconclusive, add asserts to the graph.
@@ -523,7 +528,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
is_positive_definite=None,
assert_proper_shapes=False,
name="LinearOperatorScaledIdentity"):
"""Initialize a `LinearOperatorScaledIdentity`.
r"""Initialize a `LinearOperatorScaledIdentity`.
The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which
determines the size of each identity matrix, and a `multiplier`,
@@ -538,7 +543,12 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
transpose.
is_positive_definite: Expect that this operator is positive definite.
is_positive_definite: Expect that this operator is positive definite,
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices
assert_proper_shapes: Python `bool`. If `False`, only perform static
checks that initialization and method arguments have proper shape.
If `True`, and static checks are inconclusive, add asserts to the graph.

View File

@@ -23,7 +23,6 @@ from tensorflow.contrib.linalg.python.ops import linear_operator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@@ -108,7 +107,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
is_self_adjoint=None,
is_positive_definite=None,
name="LinearOperatorTriL"):
"""Initialize a `LinearOperatorTriL`.
r"""Initialize a `LinearOperatorTriL`.
Args:
tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
@@ -122,9 +121,10 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
real-valued diagonal entries. In this case it is advised to use
`LinearOperatorDiag`.
is_positive_definite: Expect that this operator is positive definite,
meaning the real part of all eigenvalues is positive. We do not require
the operator to be self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix
meaning the quadratic form `x^H A x` has positive real part for all
nonzero `x`. Note that we do not require the operator to be
self-adjoint to be positive-definite. See:
https://en.wikipedia.org/wiki/Positive-definite_matrix\
#Extension_for_non_symmetric_matrices
name: A name for this `LinearOperator`.
@@ -173,20 +173,6 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
self._diag,
message="Singular operator: Diagonal contained zero values.")
def _assert_positive_definite(self):
if self.dtype.is_complex:
message = (
"Diagonal operator had diagonal entries with non-positive real part, "
"thus was not positive definite.")
else:
message = (
"Real diagonal operator had non-positive diagonal entries, "
"thus was not positive definite.")
return check_ops.assert_positive(
math_ops.real(self._diag),
message=message)
def _apply(self, x, adjoint=False):
return math_ops.matmul(self._tril, x, adjoint_a=adjoint)

View File

@@ -170,6 +170,8 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
Default is `None`, unless `base_operator` is positive-definite,
`v = None` (meaning `u=v`), and `is_diag_update_positive` is `True`, in
which case this defaults to `True`.
Note that we say an operator is positive definite when the quadratic
form `x^H A x` has positive real part for all nonzero `x`.
is_square: Expect that this operator acts like square [batch] matrices.
name: A name for this `LinearOperator`.

View File

@@ -87,13 +87,208 @@ def assert_compatible_matrix_dimensions(operator, x):
assert_same_dd = check_ops.assert_equal(
array_ops.shape(x)[-2],
operator.domain_dimension_tensor(),
message=(
"Incompatible matrix dimensions. "
"shape[-2] of argument to be the same as this operator"))
message=("Incompatible matrix dimensions. "
"shape[-2] of argument to be the same as this operator"))
return assert_same_dd
def assert_is_batch_matrix(tensor):
"""Static assert that `tensor` has rank `2` or higher."""
sh = tensor.get_shape()
if sh.ndims is not None and sh.ndims < 2:
raise ValueError(
"Expected [batch] matrix to have at least two dimensions. Found: "
"%s" % tensor)
def broadcast_matrix_batch_dims(batch_matrices, name=None):
"""Broadcast leading dimensions of zero or more [batch] matrices.
Example broadcasting one batch dim of two simple matrices.
```python
x = [[1, 2],
[3, 4]] # Shape [2, 2], no batch dims
y = [[[1]]] # Shape [1, 1, 1], 1 batch dim of shape [1]
x_bc, y_bc = broadcast_matrix_batch_dims([x, y])
x_bc
==> [[[1, 2],
[3, 4]]] # Shape [1, 2, 2], 1 batch dim of shape [1].
y_bc
==> same as y
```
Example broadcasting many batch dims
```python
x = tf.random_normal(shape=(2, 3, 1, 4, 4))
y = tf.random_normal(shape=(1, 3, 2, 5, 5))
x_bc, y_bc = broadcast_matrix_batch_dims([x, y])
x_bc.shape
==> (2, 3, 2, 4, 4)
y_bc.shape
==> (2, 3, 2, 5, 5)
```
Args:
batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
name: A string name to prepend to created ops.
Returns:
bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
the values from `batch_matrices[i]`, with possibly broadcast batch dims.
Raises:
ValueError: If any input `Tensor` is statically determined to have less
than two dimensions.
"""
with ops.name_scope(
name or "broadcast_matrix_batch_dims", values=batch_matrices):
check_ops.assert_proper_iterable(batch_matrices)
batch_matrices = list(batch_matrices)
for i, mat in enumerate(batch_matrices):
batch_matrices[i] = ops.convert_to_tensor(mat)
assert_is_batch_matrix(batch_matrices[i])
if len(batch_matrices) < 2:
return batch_matrices
# Try static broadcasting.
# bcast_batch_shape is the broadcast batch shape of ALL matrices.
# E.g. if batch_matrices = [x, y], with
# x.shape = [2, j, k] (batch shape = [2])
# y.shape = [3, 1, l, m] (batch shape = [3, 1])
# ==> bcast_batch_shape = [3, 2]
bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
for mat in batch_matrices[1:]:
bcast_batch_shape = array_ops.broadcast_static_shape(
bcast_batch_shape, mat.get_shape()[:-2])
if bcast_batch_shape.is_fully_defined():
# The [1, 1] at the end will broadcast with anything.
bcast_shape = bcast_batch_shape.concatenate([1, 1])
for i, mat in enumerate(batch_matrices):
if mat.get_shape()[:-2] != bcast_batch_shape:
batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
return batch_matrices
# Since static didn't work, do dynamic, which always copies data.
bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
for mat in batch_matrices[1:]:
bcast_batch_shape = array_ops.broadcast_dynamic_shape(
bcast_batch_shape, array_ops.shape(mat)[:-2])
bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
for i, mat in enumerate(batch_matrices):
batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
return batch_matrices
def _broadcast_to_shape(x, shape):
return x + array_ops.zeros(shape=shape, dtype=x.dtype)
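# Commentary (added, not in the original patch): adding a zeros tensor of the
# target shape delegates the copy to standard broadcasting rules. The
# trailing [1, 1] in `bcast_shape` broadcasts against each matrix's last two
# dims, so only the batch dims are expanded. For example, if
# x.shape == (2, 1, 3, 3) and shape == (2, 5, 1, 1), then
# x + zeros(shape) has shape (2, 5, 3, 3) with the matrix entries unchanged.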
def matmul_with_broadcast(a,
b,
transpose_a=False,
transpose_b=False,
adjoint_a=False,
adjoint_b=False,
a_is_sparse=False,
b_is_sparse=False,
name=None):
"""Multiplies matrix `a` by matrix `b`, producing `a @ b`.
The inputs must be matrices (or tensors of rank > 2, representing batches of
matrices).
Both matrices must be of the same type. The supported types are:
`float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.
Either matrix can be transposed or adjointed (conjugated and transposed) on
the fly by setting one of the corresponding flags to `True`. These are `False`
by default.
If one or both of the matrices contain a lot of zeros, a more efficient
multiplication algorithm can be used by setting the corresponding
`a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
This optimization is only available for plain matrices (rank-2 tensors) with
datatypes `bfloat16` or `float32`.
For example:
```python
# A 2-batch of 3x4 matrices
a = tf.random_normal(shape=(2, 3, 4))
# A single 4x5 matrix
b = tf.random_normal(shape=(4, 5))
result = matmul_with_broadcast(a, b)
result.shape
==> (2, 3, 5)
result[0,...]
==> tf.matmul(a[0,...], b)
result[1,...]
==> tf.matmul(a[1,...], b)
```
Args:
a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
`complex128` and `rank > 1`.
b: `Tensor` with same type as `a` having compatible matrix dimensions and
broadcastable batch dimensions.
transpose_a: If `True`, `a` is transposed before multiplication.
transpose_b: If `True`, `b` is transposed before multiplication.
adjoint_a: If `True`, `a` is conjugated and transposed before
multiplication.
adjoint_b: If `True`, `b` is conjugated and transposed before
multiplication.
a_is_sparse: If `True`, `a` is treated as a sparse matrix.
b_is_sparse: If `True`, `b` is treated as a sparse matrix.
name: Name for the operation (optional).
Returns:
A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
the product of the corresponding matrices in `a` and `b`, e.g. if all
transpose or adjoint attributes are `False`:
The leading shape of `output` is the result of broadcasting the leading
dimensions of `a` and `b`.
`output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]),
for all indices i, j.
Note: This is matrix product, not element-wise product.
Raises:
ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
are both set to True.
"""
with ops.name_scope(name, "MatMulWithBroadcast", [a, b]) as name:
a, b = broadcast_matrix_batch_dims([a, b])
return math_ops.matmul(
a,
b,
transpose_a=transpose_a,
transpose_b=transpose_b,
adjoint_a=adjoint_a,
adjoint_b=adjoint_b,
a_is_sparse=a_is_sparse,
b_is_sparse=b_is_sparse)
def shape_tensor(shape, name=None):
"""Convert Tensor using default type, unless empty list or tuple."""
# Works just like random_ops._ShapeTensor.

View File

@@ -270,6 +270,8 @@ def sdca_model_fn(features, labels, mode, params, config=None):
with variable_scope.variable_op_scope(features.values(),
parent_scope) as scope:
features = features.copy()
features.update(layers.transform_features(features, feature_columns))
logits, columns_to_variables, bias = (
layers.weighted_sum_from_feature_columns(
columns_to_tensors=features,

View File

@@ -1879,11 +1879,11 @@ def streaming_pearson_correlation(predictions,
math_ops.multiply(math_ops.sqrt(var_predictions),
math_ops.sqrt(var_labels)),
'pearson_r')
with ops.control_dependencies(
[update_cov, update_var_predictions, update_var_labels]):
update_op = _safe_div(update_cov, math_ops.multiply(
math_ops.sqrt(update_var_predictions),
math_ops.sqrt(update_var_labels)), 'update_op')
update_op = _safe_div(
update_cov,
math_ops.multiply(math_ops.sqrt(update_var_predictions),
math_ops.sqrt(update_var_labels)),
'update_op')
if metrics_collections:
ops.add_to_collections(metrics_collections, pearson_r)
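# Commentary (added, not part of the change): the explicit
# control_dependencies block is presumably dropped because update_cov,
# update_var_predictions, and update_var_labels already feed update_op as
# data dependencies; before and after, the op computes
#   pearson_r = cov(x, y) / (sqrt(var(x)) * sqrt(var(y))),
# so the computed value should be unchanged.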

View File

@@ -42,6 +42,7 @@ See @{$python/contrib.rnn} guide.
@@GridLSTMCell
@@BidirectionalGridLSTMCell
@@NASCell
@@PhasedLSTMCell
### RNNCell wrappers
@@AttentionCellWrapper

View File

@@ -173,7 +173,6 @@ class RNNCellTest(test.TestCase):
with self.test_session() as sess:
num_units = 8
batch_size = 3
input_size = 4
feature_size = 2
frequency_skip = 1
num_frequency_blocks = [1, 1]
@@ -844,6 +843,45 @@ class RNNCellTest(test.TestCase):
"be set to num_units at cell init."):
cell(inputs, init_state)
def testPhasedLSTMCell(self):
with self.test_session() as sess:
num_units = 2
batch_size = 3
input_size = 4
expected_state_c = np.array(
[[2.954548e-01, 8.354891e-04],
[2.834632e-01, 8.158963e-01],
[2.291694e-01, 1.325745e-04]],
dtype=np.float32)
expected_state_h = np.array(
[[2.116566e-01, 5.985238e-04],
[2.137760e-01, 6.153145e-01],
[1.742966e-01, 1.008306e-04]],
dtype=np.float32)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
t = array_ops.zeros([batch_size, 1], dtype=dtypes.float64)
x = array_ops.zeros([batch_size, input_size])
c0 = array_ops.zeros([batch_size, 2])
h0 = array_ops.zeros([batch_size, 2])
state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
output, state = rnn_cell.PhasedLSTMCell(num_units=num_units)((t, x),
state0)
sess.run([variables.global_variables_initializer()])
res = sess.run([output, state], {
t.name:
np.array([[1.], [2.], [3.]]),
x.name:
np.array([[1., 1., 1., 1.],
[2., 2., 2., 2.],
[3., 3., 3., 3.]]),
})
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
self.assertAllClose(res[1].c, expected_state_c)
self.assertAllClose(res[1].h, expected_state_h)
class LayerNormBasicLSTMCellTest(test.TestCase):

View File

@@ -33,6 +33,7 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@@ -1683,3 +1684,178 @@ class CompiledWrapper(core_rnn_cell.RNNCell):
with jit.experimental_jit_scope(compile_ops=compile_ops):
return self._cell(inputs, state, scope=scope)
def _random_exp_initializer(minval,
maxval,
seed=None,
dtype=dtypes.float32):
"""Returns an exponential distribution initializer.
Args:
minval: float or a scalar float Tensor with value > 0. Lower bound of the
range of random values to generate.
maxval: float or a scalar float Tensor with value > minval. Upper bound of
the range of random values to generate.
seed: An integer. Used to create random seeds.
dtype: The data type.
Returns:
An initializer that generates tensors with an exponential distribution.
"""
def _initializer(shape, dtype=dtype, partition_info=None):
del partition_info # Unused.
return math_ops.exp(
random_ops.random_uniform(
shape,
math_ops.log(minval),
math_ops.log(maxval),
dtype,
seed=seed))
return _initializer
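# Commentary (added, not in the patch): the values drawn are log-uniform on
# [minval, maxval]. A quick NumPy sketch of the same sampling scheme:
#
#   import numpy as np
#   u = np.random.uniform(np.log(1.0), np.log(1000.0), size=10000)
#   periods = np.exp(u)
#   assert 1.0 <= periods.min() and periods.max() <= 1000.0
#
# so oscillation periods are spread across several orders of magnitude
# rather than clustered near the arithmetic mean of the bounds.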
class PhasedLSTMCell(core_rnn_cell.RNNCell):
"""Phased LSTM recurrent network cell.
https://arxiv.org/pdf/1610.09513v1.pdf
"""
def __init__(self,
num_units,
use_peepholes=False,
leak=0.001,
ratio_on=0.1,
trainable_ratio_on=True,
period_init_min=1.0,
period_init_max=1000.0,
reuse=None):
"""Initialize the Phased LSTM cell.
Args:
num_units: int, The number of units in the Phased LSTM cell.
use_peepholes: bool, set True to enable peephole connections.
leak: float or scalar float Tensor with value in [0, 1]. Leak applied
during training.
ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
period during which the gates are open.
trainable_ratio_on: bool, whether ratio_on is trainable.
period_init_min: float or scalar float Tensor with value > 0.
Minimum value of the initialized period.
The period values are initialized by drawing from the distribution:
e^U(log(period_init_min), log(period_init_max))
where U(.,.) is the uniform distribution.
period_init_max: float or scalar float Tensor with value > period_init_min.
Maximum value of the initialized period.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
self._num_units = num_units
self._use_peepholes = use_peepholes
self._leak = leak
self._ratio_on = ratio_on
self._trainable_ratio_on = trainable_ratio_on
self._period_init_min = period_init_min
self._period_init_max = period_init_max
self._reuse = reuse
@property
def state_size(self):
return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
@property
def output_size(self):
return self._num_units
def _mod(self, x, y):
"""Modulo function that propagates x gradients."""
return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x
def _get_cycle_ratio(self, time, phase, period):
"""Compute the cycle ratio in the dtype of the time."""
phase_casted = math_ops.cast(phase, dtype=time.dtype)
period_casted = math_ops.cast(period, dtype=time.dtype)
shifted_time = time - phase_casted
cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
return math_ops.cast(cycle_ratio, dtype=dtypes.float32)
def __call__(self, inputs, state, scope=None):
"""Phased LSTM Cell.
Args:
inputs: A tuple of 2 Tensors.
The first Tensor has shape [batch, 1], and type float32 or float64.
It stores the time.
The second Tensor has shape [batch, features_size], and type float32.
It stores the features.
state: core_rnn_cell.LSTMStateTuple, state from previous timestep.
scope: string, id of the variable scope.
Returns:
A tuple containing:
- A Tensor of float32 with shape [batch_size, num_units], representing the
output of the cell.
- A core_rnn_cell.LSTMStateTuple containing 2 float32 Tensors of shape
[batch_size, num_units], representing the new state and the output.
"""
with _checked_scope(self, scope or "phased_lstm_cell", reuse=self._reuse):
(c_prev, h_prev) = state
(time, x) = inputs
in_mask_gates = [x, h_prev]
if self._use_peepholes:
in_mask_gates.append(c_prev)
with vs.variable_scope("mask_gates"):
mask_gates = math_ops.sigmoid(
_linear(in_mask_gates, 2 * self._num_units, True))
[input_gate, forget_gate] = array_ops.split(
axis=1, num_or_size_splits=2, value=mask_gates)
with vs.variable_scope("new_input"):
new_input = math_ops.tanh(
_linear([x, h_prev], self._num_units, True))
new_c = (c_prev * forget_gate + input_gate * new_input)
in_out_gate = [x, h_prev]
if self._use_peepholes:
in_out_gate.append(new_c)
with vs.variable_scope("output_gate"):
output_gate = math_ops.sigmoid(
_linear(in_out_gate, self._num_units, True))
new_h = math_ops.tanh(new_c) * output_gate
period = vs.get_variable(
"period", [self._num_units],
initializer=_random_exp_initializer(
self._period_init_min, self._period_init_max))
phase = vs.get_variable(
"phase", [self._num_units],
initializer=init_ops.random_uniform_initializer(
0., period.initial_value))
ratio_on = vs.get_variable(
"ratio_on", [self._num_units],
initializer=init_ops.constant_initializer(self._ratio_on),
trainable=self._trainable_ratio_on)
cycle_ratio = self._get_cycle_ratio(time, phase, period)
k_up = 2 * cycle_ratio / ratio_on
k_down = 2 - k_up
k_closed = self._leak * cycle_ratio
k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
new_c = k * new_c + (1 - k) * c_prev
new_h = k * new_h + (1 - k) * h_prev
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
return new_h, new_state
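The `k_up`/`k_down`/`k_closed` arithmetic above is the piecewise-linear time gate of the Phased LSTM paper: the gate ramps open during the first half of `ratio_on`, ramps closed during the second half, and otherwise passes only a small leak. A minimal NumPy sketch (illustrative only; `time_gate` is not a name from the patch):

```python
import numpy as np

def time_gate(cycle_ratio, ratio_on=0.1, leak=0.001):
  k_up = 2. * cycle_ratio / ratio_on      # rising half: 0 -> 1
  k_down = 2. - k_up                      # falling half: 1 -> 0
  k_closed = leak * cycle_ratio           # closed phase: small leak
  k = np.where(cycle_ratio < ratio_on, k_down, k_closed)
  return np.where(cycle_ratio < 0.5 * ratio_on, k_up, k)

print(time_gate(np.array([0.02, 0.08, 0.5])))  # ~[0.4, 0.4, 0.0005]
```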

View File

@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import
import sys
import functools
import numpy as np
@@ -38,6 +39,15 @@ from tensorflow.python.util import nest
# pylint: enable=g-import-not-at-top
# for testing
AttentionWrapperState = wrapper.AttentionWrapperState # pylint: disable=invalid-name
LSTMStateTuple = core_rnn_cell.LSTMStateTuple # pylint: disable=invalid-name
BasicDecoderOutput = basic_decoder.BasicDecoderOutput # pylint: disable=invalid-name
float32 = np.float32
int32 = np.int32
array = np.array
class AttentionWrapperTest(test.TestCase):
def assertAllClose(self, *args, **kwargs):
@@ -48,10 +58,11 @@ class AttentionWrapperTest(test.TestCase):
def _testWithAttention(self,
create_attention_mechanism,
expected_final_outputs,
expected_final_output,
expected_final_state,
attention_mechanism_depth=3,
attention_history=False):
attention_history=False,
name=""):
encoder_sequence_length = [3, 2, 3, 1, 0]
decoder_sequence_length = [2, 0, 1, 2, 3]
batch_size = 5
@@ -126,7 +137,13 @@ class AttentionWrapperTest(test.TestCase):
"state_attention_history": state_attention_history,
})
nest.map_structure(self.assertAllClose, expected_final_outputs,
print("Copy/paste (%s)\nexpected_final_output = " % name,
sess_results["final_outputs"])
sys.stdout.flush()
print("Copy/paste (%s)\nexpected_final_state = " % name,
sess_results["final_state"])
sys.stdout.flush()
nest.map_structure(self.assertAllClose, expected_final_output,
sess_results["final_outputs"])
nest.map_structure(self.assertAllClose, expected_final_state,
sess_results["final_state"])
@@ -137,533 +154,534 @@ class AttentionWrapperTest(test.TestCase):
np.transpose(sess_results["final_outputs"].rnn_output,
(1, 0, 2)))
def testBahndahauNotNormalized(self):
def testBahdanauNotNormalized(self):
create_attention_mechanism = wrapper.BahdanauAttention
array = np.array
float32 = np.float32
int32 = np.int32
expected_final_outputs = basic_decoder.BasicDecoderOutput(
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
1.25166783e-02, -6.88887993e-03, 3.17239435e-03,
-1.98234897e-03, 4.77387803e-03, -1.38330357e-02
1.89980457e-03, 1.89681584e-03, 2.05339328e-03, -3.83376027e-03,
-4.31808922e-03, -6.45466987e-03
], [
1.28883058e-02, -6.76271692e-03, 3.13419267e-03,
-2.02183682e-03, 5.62057737e-03, -1.35373026e-02
2.27232254e-03, 2.02509761e-03, 2.01666891e-03, -3.87230632e-03,
-3.47119337e-03, -6.15991233e-03
], [
1.24917831e-02, -6.71574520e-03, 3.42238229e-03,
-1.79501204e-03, 5.33161033e-03, -1.36620644e-02
1.87640532e-03, 2.07374478e-03, 2.30582547e-03, -3.64564802e-03,
-3.75995948e-03, -6.28685066e-03
]], [[
1.55150667e-02, -1.07274549e-02, 4.44198400e-03,
-9.73310322e-04, 1.27242506e-02, -1.21861566e-02
4.89835022e-03, -1.94158917e-03, 3.32316267e-03,
-2.82446202e-03, 3.63192149e-03, -4.80734091e-03
], [
1.57585666e-02, -1.07965544e-02, 4.61554807e-03,
-1.01510016e-03, 1.22341057e-02, -1.27029382e-02
5.14256489e-03, -2.00877781e-03, 3.49807227e-03,
-2.86567654e-03, 3.14202951e-03, -5.32575324e-03
], [
1.58304181e-02, -1.09712025e-02, 4.67861444e-03,
-1.03920139e-03, 1.23004699e-02, -1.25949886e-02
5.21511910e-03, -2.18198029e-03, 3.56219849e-03,
-2.88951304e-03, 3.20866983e-03, -5.21918852e-03
]], [[
9.26700700e-03, -9.75431874e-03, -9.95740294e-04,
-1.27463136e-06, 3.81659716e-03, -1.64887272e-02
-1.34951377e-03, -9.68646549e-04, -2.11444520e-03,
-1.85243192e-03, -5.27541339e-03, -9.10969637e-03
], [
9.25191958e-03, -9.80092678e-03, -8.48566880e-04,
5.02091134e-05, 3.46567202e-03, -1.67435352e-02
-1.36390887e-03, -1.01293903e-03, -1.96592091e-03,
-1.80044665e-03, -5.62618347e-03, -9.36636236e-03
], [
9.48173273e-03, -9.52653307e-03, -8.79382715e-04,
-3.07094306e-05, 4.05955408e-03, -1.67226996e-02
-1.13357347e-03, -7.37126335e-04, -1.99582824e-03,
-1.88097963e-03, -5.03196474e-03, -9.34652984e-03
]], [[
1.21462569e-02, -1.27578378e-02, 1.54045003e-04, 2.70257704e-03,
7.79421115e-03, -8.14041123e-04
1.52963377e-03, -3.97205260e-03, -9.64675564e-04,
8.51404853e-04, -1.29804458e-03, 6.56467676e-03
], [
1.18412934e-02, -1.33513296e-02, 3.54760559e-05, 2.67801876e-03,
6.99122995e-03, -9.46014654e-04
1.22557906e-03, -4.56343032e-03, -1.08188344e-03,
8.27252632e-04, -2.10058759e-03, 6.43082103e-03
], [
1.16087487e-02, -1.31632648e-02, -2.98853614e-04,
2.49515846e-03, 6.92677684e-03, -6.92734495e-04
9.93478228e-04, -4.37378604e-03, -1.41531695e-03,
6.44775166e-04, -2.16480484e-03, 6.68286439e-03
]], [[
1.02377674e-02, -8.72955937e-03, 1.22555892e-03, 2.03830865e-03,
8.93574394e-03, -7.28237582e-03
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
1.05115287e-02, -8.92531779e-03, 1.14568521e-03, 1.91635895e-03,
8.94328393e-03, -7.39541650e-03
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03,
9.36511997e-03, -7.64556089e-03
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
expected_final_state = wrapper.AttentionWrapperState(
time=3,
attention_history=(),
cell_state=core_rnn_cell.LSTMStateTuple(
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=array(
[[
-0.0220502, -0.008058, -0.00160266, 0.01609341, -0.01380513,
-0.00749483, -0.00816989, -0.01210028, 0.01795324
-2.18963176e-02, -8.04424379e-03, -1.48289464e-03,
1.61068402e-02, -1.37983467e-02, -7.57976994e-03,
-8.28560349e-03, -1.18737305e-02, 1.78835373e-02
], [
0.01727026, -0.0142065, -0.00399991, 0.03195379,
-0.03547479, -0.02138772, -0.00610318, -0.00191625,
-0.01937846
1.74205080e-02, -1.41929444e-02, -3.88092734e-03,
3.19708064e-02, -3.54689620e-02, -2.14698724e-02,
-6.21716119e-03, -1.69295724e-03, -1.94495302e-02
], [
-0.0116077, 0.00876439, -0.01641787, -0.01400803,
0.01347527, -0.01036386, 0.00627491, -0.0096361, -0.00650565
-1.14528481e-02, 8.77819210e-03, -1.62970200e-02,
-1.39963552e-02, 1.34831406e-02, -1.04494914e-02,
6.16127765e-03, -9.41022579e-03, -6.57590060e-03
], [
-0.04763387, -0.01192631, -0.00019412, 0.04103886,
-0.00137999, 0.02126684, -0.02793711, -0.05467696,
-0.02912051
-4.74753827e-02, -1.19123599e-02, -7.40140676e-05,
4.10552323e-02, -1.36711076e-03, 2.11795457e-02,
-2.80460119e-02, -5.44509329e-02, -2.91906092e-02
], [
0.02241185, -0.00141741, 0.01911988, 0.00547728,
-0.01280068, -0.00307024, -0.00494239, 0.02169247,
0.01631995
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.10613741e-02, -3.98175791e-03, -8.15514475e-04,
7.90482666e-03, -7.02390168e-03, -3.76394135e-03,
-4.16183751e-03, -6.17114361e-03, 8.95532221e-03
-1.09840557e-02, -3.97477299e-03, -7.54582870e-04,
7.91188516e-03, -7.02184858e-03, -3.80711886e-03,
-4.22059745e-03, -6.05464494e-03, 8.92061181e-03
], [
8.60657450e-03, -7.17655150e-03, -1.94156705e-03,
1.62583217e-02, -1.76821016e-02, -1.06200138e-02,
-3.01904045e-03, -9.57608980e-04, -9.95732192e-03
8.68131686e-03, -7.16938032e-03, -1.88384682e-03,
1.62678920e-02, -1.76827926e-02, -1.06622791e-02,
-3.07528162e-03, -8.45885137e-04, -9.99388192e-03
], [
-5.78935863e-03, 4.49362956e-03, -8.13615043e-03,
-6.95384294e-03, 6.75151078e-03, -5.07845683e-03,
3.11869266e-03, -4.72904649e-03, -3.20469099e-03
-5.71205560e-03, 4.50050412e-03, -8.07640795e-03,
-6.94844872e-03, 6.75682165e-03, -5.12113515e-03,
3.06208082e-03, -4.61743120e-03, -3.23931244e-03
], [
-2.38025561e-02, -5.89242764e-03, -9.76260417e-05,
2.01697368e-02, -6.82076614e-04, 1.07111251e-02,
-1.42077375e-02, -2.70790439e-02, -1.44685479e-02
-2.37231534e-02, -5.88526297e-03, -3.72226204e-05,
2.01789513e-02, -6.75848918e-04, 1.06686354e-02,
-1.42624676e-02, -2.69628745e-02, -1.45034352e-02
], [
1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
-2.48806458e-03, 1.10462429e-02, 7.97568541e-03
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
1.24917831e-02, -6.71574520e-03, 3.42238229e-03,
-1.79501204e-03, 5.33161033e-03, -1.36620644e-02
0.00187641, 0.00207374, 0.00230583, -0.00364565, -0.00375996,
-0.00628685
], [
1.58304181e-02, -1.09712025e-02, 4.67861444e-03,
-1.03920139e-03, 1.23004699e-02, -1.25949886e-02
0.00521512, -0.00218198, 0.0035622, -0.00288951, 0.00320867,
-0.00521919
], [
9.48173273e-03, -9.52653307e-03, -8.79382715e-04,
-3.07094306e-05, 4.05955408e-03, -1.67226996e-02
-0.00113357, -0.00073713, -0.00199583, -0.00188098, -0.00503196,
-0.00934653
], [
1.16087487e-02, -1.31632648e-02, -2.98853614e-04,
2.49515846e-03, 6.92677684e-03, -6.92734495e-04
0.00099348, -0.00437379, -0.00141532, 0.00064478, -0.0021648,
0.00668286
], [
1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03,
9.36511997e-03, -7.64556089e-03
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32))
self._testWithAttention(create_attention_mechanism, expected_final_outputs,
expected_final_state, attention_history=True)
dtype=float32),
time=3,
attention_history=())
def testBahndahauNormalized(self):
self._testWithAttention(
create_attention_mechanism,
expected_final_output,
expected_final_state,
attention_history=True,
name="testBahdanauNotNormalized")
def testBahdanauNormalized(self):
create_attention_mechanism = functools.partial(
wrapper.BahdanauAttention, normalize=True, attention_r_initializer=2.0)
wrapper.BahdanauAttention, normalize=True)
array = np.array
float32 = np.float32
int32 = np.int32
expected_final_output = basic_decoder.BasicDecoderOutput(
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
1.72670335e-02, -5.83671592e-03, 6.38638902e-03,
-8.11776379e-04, 1.12681929e-03, -1.24236047e-02
6.64783875e-03, 2.94425711e-03, 5.26542449e-03, -2.64955591e-03,
-7.95925129e-03, -5.02286293e-03
], [
1.75918192e-02, -5.73426578e-03, 6.29768707e-03,
-8.63141613e-04, 2.03352375e-03, -1.21420780e-02
7.01954123e-03, 3.07301106e-03, 5.22849336e-03, -2.68844375e-03,
-7.11239874e-03, -4.72904276e-03
], [
1.72424167e-02, -5.66471322e-03, 6.63427915e-03,
-6.23903936e-04, 1.68706616e-03, -1.22524602e-02
6.62360899e-03, 3.12234787e-03, 5.51807694e-03, -2.46222341e-03,
-7.40198931e-03, -4.85701021e-03
]], [[
1.79958157e-02, -9.80986748e-03, 4.73218597e-03,
-3.89962713e-03, 1.41502675e-02, -1.48344040e-02
7.37589924e-03, -1.02620223e-03, 3.61374952e-03,
-5.74620720e-03, 5.05625410e-03, -7.45209027e-03
], [
1.82184577e-02, -9.88379307e-03, 4.90130857e-03,
-3.91892251e-03, 1.36479288e-02, -1.53291579e-02
7.61946291e-03, -1.09287468e-03, 3.78817180e-03,
-5.78709645e-03, 4.56611114e-03, -7.96987582e-03
], [
1.83001235e-02, -1.00617753e-02, 4.97077405e-03,
-3.94908339e-03, 1.37211196e-02, -1.52311027e-02
7.69207766e-03, -1.26582675e-03, 3.85218812e-03,
-5.81111759e-03, 4.63287206e-03, -7.86337163e-03
]], [[
7.93476030e-03, -8.46967567e-03, -7.16930721e-04,
4.37953044e-04, 1.04503892e-03, -1.82424393e-02
-2.69413739e-03, 3.47183552e-04, -1.82145904e-03,
-1.39805069e-03, -8.05486552e-03, -1.08372131e-02
], [
7.90629163e-03, -8.48819874e-03, -5.57833235e-04,
5.02390554e-04, 6.79406337e-04, -1.84837580e-02
-2.70848931e-03, 3.03293345e-04, -1.67230750e-03,
-1.34555507e-03, -8.40565283e-03, -1.10935047e-02
], [
8.14734399e-03, -8.23053624e-03, -5.92814526e-04,
4.16347990e-04, 1.29250437e-03, -1.84548404e-02
-2.47822329e-03, 5.79408603e-04, -1.70188327e-03,
-1.42583530e-03, -7.81180616e-03, -1.10740755e-02
]], [[
1.21026095e-02, -1.26739489e-02, 1.78718648e-04, 2.68748170e-03,
7.80996867e-03, -9.69076063e-04
1.48582947e-03, -3.88786104e-03, -9.39912978e-04,
8.36255029e-04, -1.28223014e-03, 6.40908210e-03
], [
1.17978491e-02, -1.32678337e-02, 6.00410858e-05, 2.66301399e-03,
7.00691342e-03, -1.10030361e-03
1.18177081e-03, -4.47923271e-03, -1.05711201e-03,
8.12121783e-04, -2.08477327e-03, 6.27523474e-03
], [
1.15651665e-02, -1.30795036e-02, -2.74205930e-04,
2.48012133e-03, 6.94250735e-03, -8.47495161e-04
9.49664740e-04, -4.28957958e-03, -1.39053771e-03,
6.29657647e-04, -2.14899099e-03, 6.52727811e-03
]], [[
1.02377674e-02, -8.72955937e-03, 1.22555892e-03, 2.03830865e-03,
8.93574394e-03, -7.28237582e-03
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
1.05115287e-02, -8.92531779e-03, 1.14568521e-03, 1.91635895e-03,
8.94328393e-03, -7.39541650e-03
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03,
9.36511997e-03, -7.64556089e-03
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[0, 0, 0], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
expected_final_state = wrapper.AttentionWrapperState(
time=3,
attention_history=(),
cell_state=core_rnn_cell.LSTMStateTuple(
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=array(
[[
-0.02209264, -0.00794879, -0.00157153, 0.01614309,
-0.01383773, -0.00750943, -0.00824213, -0.01210296,
0.01794949
-2.19389871e-02, -7.93421268e-03, -1.45148858e-03,
1.61569901e-02, -1.38310911e-02, -7.59426132e-03,
-8.35836027e-03, -1.18763093e-02, 1.78797375e-02
], [
0.01726926, -0.01418139, -0.0040099, 0.0319339, -0.03545783,
-0.02142831, -0.00609501, -0.00195033, -0.01938949
1.74194798e-02, -1.41677596e-02, -3.89095861e-03,
3.19508761e-02, -3.54519747e-02, -2.15105712e-02,
-6.20894879e-03, -1.72719418e-03, -1.94605980e-02
], [
-0.01159083, 0.0087524, -0.01639001, -0.01400012,
0.01342422, -0.01041037, 0.00620991, -0.00960796,
-0.00650131
-1.14357909e-02, 8.76635592e-03, -1.62690803e-02,
-1.39883338e-02, 1.34323873e-02, -1.04959216e-02,
6.09614328e-03, -9.38197412e-03, -6.57159975e-03
], [
-0.04763237, -0.01192762, -0.00019377, 0.04103839,
-0.00138058, 0.02126443, -0.02793917, -0.05467755,
-0.02912025
-4.74738739e-02, -1.19136795e-02, -7.36564398e-05,
4.10547666e-02, -1.36771239e-03, 2.11771261e-02,
-2.80481018e-02, -5.44515178e-02, -2.91903559e-02
], [
0.02241185, -0.00141741, 0.01911988, 0.00547728,
-0.01280068, -0.00307024, -0.00494239, 0.02169247,
0.01631995
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.10821165e-02, -3.92766716e-03, -7.99638336e-04,
7.92923011e-03, -7.04019284e-03, -3.77124036e-03,
-4.19876305e-03, -6.17261464e-03, 8.95325281e-03
-1.10049099e-02, -3.92028037e-03, -7.38571223e-04,
7.93652050e-03, -7.03821564e-03, -3.81436548e-03,
-4.25778655e-03, -6.05606195e-03, 8.91851448e-03
], [
8.60597286e-03, -7.16368994e-03, -1.94644753e-03,
1.62479617e-02, -1.76739115e-02, -1.06403306e-02,
-3.01484042e-03, -9.74688213e-04, -9.96260438e-03
8.68070032e-03, -7.15647917e-03, -1.88874488e-03,
1.62575077e-02, -1.76745858e-02, -1.06826536e-02,
-3.07105901e-03, -8.63034453e-04, -9.99918394e-03
], [
-5.78098884e-03, 4.48751403e-03, -8.12216662e-03,
-6.94991415e-03, 6.72604749e-03, -5.10144979e-03,
3.08637507e-03, -4.71517537e-03, -3.20256175e-03
-5.70359221e-03, 4.49446775e-03, -8.06238409e-03,
-6.94446685e-03, 6.73149945e-03, -5.14409645e-03,
3.02969781e-03, -4.60351165e-03, -3.23720207e-03
], [
-2.38018110e-02, -5.89307398e-03, -9.74484938e-05,
2.01694984e-02, -6.82370039e-04, 1.07099237e-02,
-1.42087601e-02, -2.70793457e-02, -1.44684138e-02
-2.37224046e-02, -5.88591257e-03, -3.70427515e-05,
2.01787166e-02, -6.76146999e-04, 1.06674293e-02,
-1.42635051e-02, -2.69631781e-02, -1.45033030e-02
], [
1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
-2.48806458e-03, 1.10462429e-02, 7.97568541e-03
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
0.01724242, -0.00566471, 0.00663428, -0.0006239, 0.00168707,
-0.01225246
0.00662361, 0.00312235, 0.00551808, -0.00246222, -0.00740199,
-0.00485701
], [
0.01830012, -0.01006178, 0.00497077, -0.00394908, 0.01372112,
-0.0152311
0.00769208, -0.00126583, 0.00385219, -0.00581112, 0.00463287,
-0.00786337
], [
0.00814734, -0.00823054, -0.00059281, 0.00041635, 0.0012925,
-0.01845484
-0.00247822, 0.00057941, -0.00170188, -0.00142584, -0.00781181,
-0.01107408
], [
0.01156517, -0.0130795, -0.00027421, 0.00248012, 0.00694251,
-0.0008475
0.00094966, -0.00428958, -0.00139054, 0.00062966, -0.00214899,
0.00652728
], [
0.01073981, -0.00856867, 0.00152354, 0.00206834, 0.00936512,
-0.00764556
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32))
dtype=float32),
time=3,
attention_history=())
self._testWithAttention(create_attention_mechanism, expected_final_output,
expected_final_state)
self._testWithAttention(
create_attention_mechanism,
expected_final_output,
expected_final_state,
name="testBahdanauNormalized")
def testLuongNotNormalized(self):
create_attention_mechanism = wrapper.LuongAttention
array = np.array
float32 = np.float32
int32 = np.int32
expected_final_output = basic_decoder.BasicDecoderOutput(
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
1.23641128e-02, -6.82715839e-03, 3.24165262e-03,
-1.90772023e-03, 4.69654519e-03, -1.37025211e-02
1.74749165e-03, 1.95862399e-03, 2.12293095e-03, -3.75889172e-03,
-4.39571124e-03, -6.32379763e-03
], [
1.29463980e-02, -6.79699238e-03, 3.10124992e-03,
-2.02869414e-03, 5.66399656e-03, -1.35517996e-02
2.33045570e-03, 1.99094601e-03, 1.98377599e-03, -3.87950847e-03,
-3.42792575e-03, -6.17497414e-03
], [
1.22659411e-02, -6.81970268e-03, 3.15135531e-03,
-1.96937821e-03, 5.62768336e-03, -1.39173865e-02
1.65032526e-03, 1.96972815e-03, 2.03462853e-03, -3.82007333e-03,
-3.46369296e-03, -6.54224353e-03
]], [[
1.53944232e-02, -1.07725551e-02, 4.42822604e-03,
-8.30623554e-04, 1.26549732e-02, -1.20573286e-02
4.77780215e-03, -1.98677275e-03, 3.30950436e-03,
-2.68179504e-03, 3.56271653e-03, -4.67860466e-03
], [
1.57453734e-02, -1.08157266e-02, 4.62466478e-03,
-9.88351414e-04, 1.22286947e-02, -1.26876952e-02
5.13039157e-03, -2.02797214e-03, 3.50760575e-03,
-2.83981953e-03, 3.13726603e-03, -5.31156827e-03
], [
1.57857724e-02, -1.09536834e-02, 4.64798324e-03,
-1.01319887e-03, 1.22695938e-02, -1.25500849e-02
5.17205056e-03, -2.16446724e-03, 3.53219034e-03,
-2.86490913e-03, 3.17879021e-03, -5.17592067e-03
]], [[
9.23123397e-03, -9.42669343e-03, -9.09919385e-04,
6.09827694e-05, 3.90436035e-03, -1.63374804e-02
-1.38538703e-03, -6.40910701e-04, -2.02864106e-03,
-1.79018872e-03, -5.18789608e-03, -8.95875692e-03
], [
9.22935922e-03, -9.57853813e-03, -7.92966573e-04,
8.89014918e-05, 3.52671882e-03, -1.66499857e-02
-1.38620089e-03, -7.92010222e-04, -1.91070826e-03,
-1.76206254e-03, -5.56525169e-03, -9.27332044e-03
], [
9.49526206e-03, -9.39475093e-03, -8.49372707e-04,
-1.72815053e-05, 4.16132808e-03, -1.66336838e-02
-1.11966045e-03, -6.07630936e-04, -1.96643686e-03,
-1.86803937e-03, -4.93048411e-03, -9.25842486e-03
]], [[
1.21248290e-02, -1.27166547e-02, 1.66158192e-04, 2.69516627e-03,
7.80194718e-03, -8.90152063e-04
1.50820788e-03, -3.93087184e-03, -9.52563598e-04,
8.43994785e-04, -1.29030924e-03, 6.48857141e-03
], [
1.17861275e-02, -1.32453050e-02, 6.66640699e-05, 2.65894993e-03,
7.01114535e-03, -1.14195189e-03
1.17029145e-03, -4.45716921e-03, -1.05062663e-03,
8.08141369e-04, -2.08062865e-03, 6.23444980e-03
], [
1.15833860e-02, -1.31145213e-02, -2.84505659e-04,
2.48642010e-03, 6.93593081e-03, -7.82784075e-04
9.67921398e-04, -4.32466762e-03, -1.40085898e-03,
6.35969569e-04, -2.15558149e-03, 6.59212377e-03
]], [[
1.02377674e-02, -8.72955937e-03, 1.22555892e-03, 2.03830865e-03,
8.93574394e-03, -7.28237582e-03
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
1.05115287e-02, -8.92531779e-03, 1.14568521e-03, 1.91635895e-03,
8.94328393e-03, -7.39541650e-03
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03,
9.36511997e-03, -7.64556089e-03
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
expected_final_state = wrapper.AttentionWrapperState(
time=3,
attention_history=(),
cell_state=core_rnn_cell.LSTMStateTuple(
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=array(
[[
-0.02204997, -0.00805805, -0.00160245, 0.01609369,
-0.01380494, -0.00749439, -0.00817, -0.01209992, 0.01795316
-2.18960866e-02, -8.04429129e-03, -1.48267671e-03,
1.61071159e-02, -1.37981661e-02, -7.57933082e-03,
-8.28570686e-03, -1.18733812e-02, 1.78834442e-02
], [
0.01727016, -0.01420713, -0.00399972, 0.03195436,
-0.03547532, -0.02138666, -0.00610335, -0.00191557,
-0.01937821
1.74204130e-02, -1.41935758e-02, -3.88074201e-03,
3.19713727e-02, -3.54694910e-02, -2.14688145e-02,
-6.21731905e-03, -1.69229065e-03, -1.94492843e-02
], [
-0.01160429, 0.00876595, -0.01641685, -0.01400784,
0.01348004, -0.01036458, 0.00627241, -0.00963544,
-0.00650568
-1.14494488e-02, 8.77974741e-03, -1.62960067e-02,
-1.39961652e-02, 1.34879015e-02, -1.04502086e-02,
6.15879148e-03, -9.40956455e-03, -6.57592434e-03
], [
-0.04763246, -0.01192755, -0.00019379, 0.04103841,
-0.00138055, 0.02126456, -0.02793905, -0.0546775,
-0.02912027
-4.74739634e-02, -1.19136050e-02, -7.36759976e-05,
4.10547927e-02, -1.36767328e-03, 2.11772677e-02,
-2.80479677e-02, -5.44514805e-02, -2.91903690e-02
], [
0.02241185, -0.00141741, 0.01911988, 0.00547728,
-0.01280068, -0.00307024, -0.00494239, 0.02169247,
0.01631995
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.10612623e-02, -3.98178305e-03, -8.15406092e-04,
7.90496264e-03, -7.02379830e-03, -3.76371504e-03,
-4.16189339e-03, -6.17096573e-03, 8.95528216e-03
-1.09839402e-02, -3.97479767e-03, -7.54472159e-04,
7.91201927e-03, -7.02175125e-03, -3.80689627e-03,
-4.22065007e-03, -6.05447078e-03, 8.92056432e-03
], [
8.60652886e-03, -7.17687514e-03, -1.94147555e-03,
1.62586085e-02, -1.76823605e-02, -1.06194830e-02,
-3.01912241e-03, -9.57269047e-04, -9.95719433e-03
8.68127123e-03, -7.16970162e-03, -1.88375649e-03,
1.62681788e-02, -1.76830534e-02, -1.06617520e-02,
-3.07536125e-03, -8.45551898e-04, -9.99375992e-03
], [
-5.78764686e-03, 4.49441886e-03, -8.13564472e-03,
-6.95375400e-03, 6.75391173e-03, -5.07880514e-03,
3.11744539e-03, -4.72871540e-03, -3.20470310e-03
-5.71034756e-03, 4.50129062e-03, -8.07590690e-03,
-6.94835978e-03, 6.75921654e-03, -5.12148207e-03,
3.06083867e-03, -4.61710012e-03, -3.23932176e-03
], [
-2.38018595e-02, -5.89303859e-03, -9.74571449e-05,
2.01695058e-02, -6.82353624e-04, 1.07099945e-02,
-1.42086931e-02, -2.70793252e-02, -1.44684194e-02
-2.37224493e-02, -5.88587578e-03, -3.70525813e-05,
2.01787278e-02, -6.76127791e-04, 1.06675029e-02,
-1.42634306e-02, -2.69631632e-02, -1.45033058e-02
], [
1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
-2.48806458e-03, 1.10462429e-02, 7.97568541e-03
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
1.22659411e-02, -6.81970268e-03, 3.15135531e-03,
-1.96937821e-03, 5.62768336e-03, -1.39173865e-02
0.00165033, 0.00196973, 0.00203463, -0.00382007, -0.00346369,
-0.00654224
], [
1.57857724e-02, -1.09536834e-02, 4.64798324e-03,
-1.01319887e-03, 1.22695938e-02, -1.25500849e-02
0.00517205, -0.00216447, 0.00353219, -0.00286491, 0.00317879,
-0.00517592
], [
9.49526206e-03, -9.39475093e-03, -8.49372707e-04,
-1.72815053e-05, 4.16132808e-03, -1.66336838e-02
-0.00111966, -0.00060763, -0.00196644, -0.00186804, -0.00493048,
-0.00925842
], [
1.15833860e-02, -1.31145213e-02, -2.84505659e-04,
2.48642010e-03, 6.93593081e-03, -7.82784075e-04
0.00096792, -0.00432467, -0.00140086, 0.00063597, -0.00215558,
0.00659212
], [
1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03,
9.36511997e-03, -7.64556089e-03
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32))
dtype=float32),
time=3,
attention_history=())
self._testWithAttention(
create_attention_mechanism,
expected_final_output,
expected_final_state,
attention_mechanism_depth=9)
attention_mechanism_depth=9,
name="testLuongNotNormalized")
def testLuongNormalized(self):
def testLuongScaled(self):
create_attention_mechanism = functools.partial(
wrapper.LuongAttention, normalize=True, attention_r_initializer=2.0)
wrapper.LuongAttention, scale=True)
array = np.array
float32 = np.float32
int32 = np.int32
expected_final_output = basic_decoder.BasicDecoderOutput(
expected_final_output = BasicDecoderOutput(
rnn_output=array(
[[[
1.23956744e-02, -6.88115368e-03, 3.15234554e-03,
-1.97300944e-03, 4.79680905e-03, -1.38076628e-02
1.74749165e-03, 1.95862399e-03, 2.12293095e-03, -3.75889172e-03,
-4.39571124e-03, -6.32379763e-03
], [
1.28376717e-02, -6.78718928e-03, 3.07988771e-03,
-2.03956687e-03, 5.68403490e-03, -1.35601182e-02
2.33045570e-03, 1.99094601e-03, 1.98377599e-03, -3.87950847e-03,
-3.42792575e-03, -6.17497414e-03
], [
1.23463338e-02, -6.76322030e-03, 3.28891934e-03,
-1.86874042e-03, 5.47897862e-03, -1.37654068e-02
1.65032526e-03, 1.96972815e-03, 2.03462853e-03, -3.82007333e-03,
-3.46369296e-03, -6.54224353e-03
]], [[
1.54412268e-02, -1.07613346e-02, 4.43824846e-03,
-8.81063985e-04, 1.26828086e-02, -1.21067995e-02
4.77780215e-03, -1.98677275e-03, 3.30950436e-03,
-2.68179504e-03, 3.56271653e-03, -4.67860466e-03
], [
1.57206059e-02, -1.08218864e-02, 4.61952807e-03,
-9.61483689e-04, 1.22140013e-02, -1.26614980e-02
5.13039157e-03, -2.02797214e-03, 3.50760575e-03,
-2.83981953e-03, 3.13726603e-03, -5.31156827e-03
], [
1.57821011e-02, -1.09842420e-02, 4.66934917e-03,
-9.85997496e-04, 1.22719472e-02, -1.25438003e-02
5.17205056e-03, -2.16446724e-03, 3.53219034e-03,
-2.86490913e-03, 3.17879021e-03, -5.17592067e-03
]], [[
9.27361846e-03, -9.66077764e-03, -9.69522633e-04,
1.48308463e-05, 3.88664147e-03, -1.64083000e-02
-1.38538703e-03, -6.40910701e-04, -2.02864106e-03,
-1.79018872e-03, -5.18789608e-03, -8.95875692e-03
], [
9.26287938e-03, -9.74234194e-03, -8.32488062e-04,
5.83778601e-05, 3.52663640e-03, -1.66827720e-02
-1.38620089e-03, -7.92010222e-04, -1.91070826e-03,
-1.76206254e-03, -5.56525169e-03, -9.27332044e-03
], [
9.50474478e-03, -9.49789397e-03, -8.71829456e-04,
-3.09986062e-05, 4.13423358e-03, -1.66635048e-02
-1.11966045e-03, -6.07630936e-04, -1.96643686e-03,
-1.86803937e-03, -4.93048411e-03, -9.25842486e-03
]], [[
1.21398102e-02, -1.27454493e-02, 1.57688977e-04, 2.70034792e-03,
7.79653806e-03, -8.36936757e-04
1.50820788e-03, -3.93087184e-03, -9.52563598e-04,
8.43994785e-04, -1.29030924e-03, 6.48857141e-03
], [
1.18234595e-02, -1.33170560e-02, 4.55579720e-05, 2.67185434e-03,
6.99766818e-03, -1.00935437e-03
1.17029145e-03, -4.45716921e-03, -1.05062663e-03,
8.08141369e-04, -2.08062865e-03, 6.23444980e-03
], [
1.16009805e-02, -1.31483339e-02, -2.94458936e-04,
2.49248254e-03, 6.92958105e-03, -7.20315147e-04
9.67921398e-04, -4.32466762e-03, -1.40085898e-03,
6.35969569e-04, -2.15558149e-03, 6.59212377e-03
]], [[
1.02377674e-02, -8.72955937e-03, 1.22555892e-03, 2.03830865e-03,
8.93574394e-03, -7.28237582e-03
-3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
-1.56512906e-04, 9.63474595e-05
], [
1.05115287e-02, -8.92531779e-03, 1.14568521e-03, 1.91635895e-03,
8.94328393e-03, -7.39541650e-03
-1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
6.56487318e-05, -1.48634164e-04, -1.84347919e-05
], [
1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03,
9.36511997e-03, -7.64556089e-03
1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
2.73401442e-04, -2.69805576e-04
]]],
dtype=float32),
sample_id=array(
[[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
dtype=int32))
expected_final_state = wrapper.AttentionWrapperState(
time=3,
attention_history=(),
cell_state=core_rnn_cell.LSTMStateTuple(
expected_final_state = AttentionWrapperState(
cell_state=LSTMStateTuple(
c=array(
[[
-0.02204949, -0.00805957, -0.001603, 0.01609283,
-0.01380462, -0.0074945, -0.00816895, -0.01210009,
0.01795324
-2.18960866e-02, -8.04429129e-03, -1.48267671e-03,
1.61071159e-02, -1.37981661e-02, -7.57933082e-03,
-8.28570686e-03, -1.18733812e-02, 1.78834442e-02
], [
0.01727016, -0.01420708, -0.00399973, 0.03195432,
-0.03547529, -0.02138673, -0.00610332, -0.00191565,
-0.01937822
1.74204130e-02, -1.41935758e-02, -3.88074201e-03,
3.19713727e-02, -3.54694910e-02, -2.14688145e-02,
-6.21731905e-03, -1.69229065e-03, -1.94492843e-02
], [
-0.01160676, 0.00876512, -0.01641791, -0.01400807,
0.01347767, -0.01036341, 0.00627499, -0.00963627,
-0.00650573
-1.14494488e-02, 8.77974741e-03, -1.62960067e-02,
-1.39961652e-02, 1.34879015e-02, -1.04502086e-02,
6.15879148e-03, -9.40956455e-03, -6.57592434e-03
], [
-0.04763342, -0.01192671, -0.00019402, 0.04103871,
-0.00138017, 0.02126611, -0.02793773, -0.05467714,
-0.02912043
-4.74739634e-02, -1.19136050e-02, -7.36759976e-05,
4.10547927e-02, -1.36767328e-03, 2.11772677e-02,
-2.80479677e-02, -5.44514805e-02, -2.91903690e-02
], [
0.02241185, -0.00141741, 0.01911988, 0.00547728,
-0.01280068, -0.00307024, -0.00494239, 0.02169247,
0.01631995
2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
-5.05525898e-03, 2.19191350e-02, 1.62497871e-02
]],
dtype=float32),
h=array(
[[
-1.10610286e-02, -3.98253463e-03, -8.15684092e-04,
7.90454168e-03, -7.02364743e-03, -3.76377185e-03,
-4.16135695e-03, -6.17104582e-03, 8.95532966e-03
-1.09839402e-02, -3.97479767e-03, -7.54472159e-04,
7.91201927e-03, -7.02175125e-03, -3.80689627e-03,
-4.22065007e-03, -6.05447078e-03, 8.92056432e-03
], [
8.60653073e-03, -7.17685232e-03, -1.94147974e-03,
1.62585936e-02, -1.76823437e-02, -1.06195193e-02,
-3.01911240e-03, -9.57308919e-04, -9.95720550e-03
8.68127123e-03, -7.16970162e-03, -1.88375649e-03,
1.62681788e-02, -1.76830534e-02, -1.06617520e-02,
-3.07536125e-03, -8.45551898e-04, -9.99375992e-03
], [
-5.78888878e-03, 4.49400023e-03, -8.13617278e-03,
-6.95386063e-03, 6.75271638e-03, -5.07823005e-03,
3.11873178e-03, -4.72912844e-03, -3.20472987e-03
-5.71034756e-03, 4.50129062e-03, -8.07590690e-03,
-6.94835978e-03, 6.75921654e-03, -5.12148207e-03,
3.06083867e-03, -4.61710012e-03, -3.23932176e-03
], [
-2.38023344e-02, -5.89262368e-03, -9.75721487e-05,
2.01696623e-02, -6.82163402e-04, 1.07107637e-02,
-1.42080421e-02, -2.70791352e-02, -1.44685050e-02
-2.37224493e-02, -5.88587578e-03, -3.70525813e-05,
2.01787278e-02, -6.76127791e-04, 1.06675029e-02,
-1.42634306e-02, -2.69631632e-02, -1.45033058e-02
], [
1.11825848e-02, -6.99267141e-04, 9.82748345e-03,
2.74566701e-03, -6.56377291e-03, -1.53681310e-03,
-2.48806458e-03, 1.10462429e-02, 7.97568541e-03
1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
-2.54477374e-03, 1.11598391e-02, 7.94144534e-03
]],
dtype=float32)),
attention=array(
[[
1.23463338e-02, -6.76322030e-03, 3.28891934e-03,
-1.86874042e-03, 5.47897862e-03, -1.37654068e-02
0.00165033, 0.00196973, 0.00203463, -0.00382007, -0.00346369,
-0.00654224
], [
1.57821011e-02, -1.09842420e-02, 4.66934917e-03,
-9.85997496e-04, 1.22719472e-02, -1.25438003e-02
0.00517205, -0.00216447, 0.00353219, -0.00286491, 0.00317879,
-0.00517592
], [
9.50474478e-03, -9.49789397e-03, -8.71829456e-04,
-3.09986062e-05, 4.13423358e-03, -1.66635048e-02
-0.00111966, -0.00060763, -0.00196644, -0.00186804, -0.00493048,
-0.00925842
], [
1.16009805e-02, -1.31483339e-02, -2.94458936e-04,
2.49248254e-03, 6.92958105e-03, -7.20315147e-04
0.00096792, -0.00432467, -0.00140086, 0.00063597, -0.00215558,
0.00659212
], [
1.07398070e-02, -8.56867433e-03, 1.52354129e-03, 2.06834078e-03,
9.36511997e-03, -7.64556089e-03
0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
-0.00026981
]],
dtype=float32))
dtype=float32),
time=3,
attention_history=())
self._testWithAttention(
create_attention_mechanism,
expected_final_output,
expected_final_state,
attention_mechanism_depth=9)
attention_mechanism_depth=9,
name="testLuongScaled")
if __name__ == "__main__":

View File

@@ -176,19 +176,18 @@ class LuongAttention(_BaseAttentionMechanism):
"Effective Approaches to Attention-based Neural Machine Translation."
EMNLP 2015. https://arxiv.org/abs/1508.04025
The second is the normalized form. This form is inspired by the
normalization proposed for Bahdanau attention in
Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck.
"Online and Linear-Time Attention by Enforcing Monotonic Alignments."
(Eq. 15).
The second is the scaled form inspired partly by the normalized form of
Bahdanau attention.
To enable the second form, construct the object with parameter
`normalize=True`.
`scale=True`.
"""
def __init__(self, num_units, memory, memory_sequence_length=None,
normalize=False, attention_r_initializer=None,
def __init__(self,
num_units,
memory,
memory_sequence_length=None,
scale=False,
name="LuongAttention"):
"""Construct the AttentionMechanism mechanism.
@ -199,31 +198,21 @@ class LuongAttention(_BaseAttentionMechanism):
memory_sequence_length (optional): Sequence lengths for the batch entries
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
normalize: Python boolean. Whether to normalize the energy term.
attention_r_initializer: Initial value of the post-normalization bias
when normalizing. Default is `0`.
scale: Python boolean. Whether to scale the energy term.
name: Name to use when creating ops.
"""
# For LuongAttention, we only transform the memory layer; thus
# num_units **must** match the expected query depth.
super(LuongAttention, self).__init__(
query_layer=None,
memory_layer=layers_core.Dense(num_units, name="memory_layer"),
memory_layer=layers_core.Dense(
num_units, name="memory_layer", use_bias=False),
memory=memory,
memory_sequence_length=memory_sequence_length,
name=name)
self._num_units = num_units
self._normalize = normalize
self._scale = scale
self._name = name
if normalize and attention_r_initializer is None:
attention_r_initializer = 0
if normalize:
with ops.name_scope(name, "LuongAttention",
[memory, attention_r_initializer]):
attention_r_initializer = ops.convert_to_tensor(
attention_r_initializer, dtype=self.values.dtype,
name="attention_r_initializer")
self._attention_r_initializer = attention_r_initializer
def __call__(self, query):
"""Score the query based on the keys and values.
@ -249,7 +238,7 @@ class LuongAttention(_BaseAttentionMechanism):
% (query, depth, self.keys, key_units, key_units))
dtype = query.dtype
with ops.name_scope(None, "LuongAttentionCall", [query]):
with variable_scope.variable_scope(None, "luong_attention", [query]):
# Reshape from [batch_size, depth] to [batch_size, 1, depth]
# for matmul.
query = array_ops.expand_dims(query, 1)
@ -266,16 +255,11 @@ class LuongAttention(_BaseAttentionMechanism):
score = math_ops.matmul(query, self.keys, transpose_b=True)
score = array_ops.squeeze(score, [1])
if self._normalize:
# Scalar used in weight normalization
if self._scale:
# Scalar used in weight scaling
g = variable_scope.get_variable(
"attention_g", dtype=dtype,
initializer=math.sqrt((1. / self._num_units)))
# Scalar bias added to attention scores
r = variable_scope.get_variable(
"attention_r", dtype=dtype,
initializer=self._attention_r_initializer)
score = g * score + r
"attention_g", dtype=dtype, initializer=1.)
score = g * score
return score
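To make the scaled form concrete, here is a NumPy sketch of the score computed above (shapes and values are illustrative; the real code runs on TF tensors):

import numpy as np

batch, max_time, num_units = 2, 5, 4
rng = np.random.default_rng(0)
query = rng.normal(size=(batch, num_units)).astype(np.float32)
keys = rng.normal(size=(batch, max_time, num_units)).astype(np.float32)
g = np.float32(1.0)  # learned scalar, initialized to 1. as in this hunk

# [batch, 1, num_units] @ [batch, num_units, max_time] -> [batch, 1, max_time]
score = np.matmul(query[:, None, :], np.transpose(keys, (0, 2, 1)))
score = g * np.squeeze(score, axis=1)  # [batch, max_time]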
@ -290,18 +274,23 @@ class BahdanauAttention(_BaseAttentionMechanism):
"Neural Machine Translation by Jointly Learning to Align and Translate."
ICLR 2015. https://arxiv.org/abs/1409.0473
The second is the normalized form, Raffel attention, as described in:
The second is the normalized form. This form is inspired by the
weight normalization article:
Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck.
"Online and Linear-Time Attention by Enforcing Monotonic Alignments."
(Eq. 15).
Tim Salimans, Diederik P. Kingma.
"Weight Normalization: A Simple Reparameterization to Accelerate
Training of Deep Neural Networks."
https://arxiv.org/abs/1602.07868
To enable the second form, construct the object with parameter
`normalize=True`.
"""
def __init__(self, num_units, memory, memory_sequence_length=None,
normalize=False, attention_r_initializer=None,
def __init__(self,
num_units,
memory,
memory_sequence_length=None,
normalize=False,
name="BahdanauAttention"):
"""Construct the Attention mechanism.
@ -313,28 +302,19 @@ class BahdanauAttention(_BaseAttentionMechanism):
in memory. If provided, the memory tensor rows are masked with zeros
for values past the respective sequence lengths.
normalize: Python boolean. Whether to normalize the energy term.
attention_r_initializer: Initial value of the post-normalization bias
when normalizing. Default is `0`.
name: Name to use when creating ops.
"""
super(BahdanauAttention, self).__init__(
query_layer=layers_core.Dense(num_units, name="query_layer"),
memory_layer=layers_core.Dense(num_units, name="memory_layer"),
query_layer=layers_core.Dense(
num_units, name="query_layer", use_bias=False),
memory_layer=layers_core.Dense(
num_units, name="memory_layer", use_bias=False),
memory=memory,
memory_sequence_length=memory_sequence_length,
name=name)
self._num_units = num_units
self._normalize = normalize
self._name = name
if normalize and attention_r_initializer is None:
attention_r_initializer = 0
if normalize:
with ops.name_scope(name, "BahdanauAttention",
[memory, attention_r_initializer]):
attention_r_initializer = ops.convert_to_tensor(
attention_r_initializer, dtype=self.values.dtype,
name="attention_r_initializer")
self._attention_r_initializer = attention_r_initializer
def __call__(self, query):
"""Score the query based on the keys and values.
@ -347,7 +327,7 @@ class BahdanauAttention(_BaseAttentionMechanism):
score: Tensor of dtype matching `self.values` and shape
`[batch_size, max_time]` (`max_time` is memory's `max_time`).
"""
with ops.name_scope(None, "BahndahauAttentionCall", [query]):
with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
processed_query = self.query_layer(query) if self.query_layer else query
dtype = processed_query.dtype
# Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
@ -363,15 +343,11 @@ class BahdanauAttention(_BaseAttentionMechanism):
b = variable_scope.get_variable(
"attention_b", [self._num_units], dtype=dtype,
initializer=init_ops.zeros_initializer())
# Scalar bias added to attention scores
r = variable_scope.get_variable(
"attention_r", dtype=dtype,
initializer=self._attention_r_initializer)
# normed_v = g * v / ||v||
normed_v = g * v * math_ops.rsqrt(
math_ops.reduce_sum(math_ops.square(v)))
score = math_ops.reduce_sum(
normed_v * math_ops.tanh(self.keys + processed_query + b), [2]) + r
normed_v * math_ops.tanh(self.keys + processed_query + b), [2])
else:
score = math_ops.reduce_sum(
v * math_ops.tanh(self.keys + processed_query), [2])
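The normalized branch is plain weight normalization of the attention vector `v`; a NumPy sketch of the reparameterization, matching the formula in this hunk (the value of `g` here is invented, since its initializer is outside the shown region):

import numpy as np

num_units = 4
rng = np.random.default_rng(1)
v = rng.normal(size=(num_units,)).astype(np.float32)
g = np.float32(0.5)  # learned scalar gain; initializer not shown in this hunk

# normed_v = g * v / ||v||: the direction of v and its length g are
# learned separately (Salimans & Kingma, cited in the new docstring).
normed_v = g * v / np.sqrt(np.sum(np.square(v)))
# score = reduce_sum(normed_v * tanh(keys + processed_query + b), axis=2)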
@ -481,7 +457,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
self._attention_mechanism = attention_mechanism
self._attention_size = attention_size
self._attention_layer = layers_core.Dense(
attention_size, bias_initializer=None)
attention_size, name="attention_layer", use_bias=False)
self._cell_input_fn = cell_input_fn
self._probability_fn = probability_fn
self._output_attention = output_attention
@ -550,44 +526,44 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
if scope is not None:
raise NotImplementedError("scope not None is not supported")
# Step 1: Calculate the true inputs to the cell based on the
# previous attention value.
cell_inputs = self._cell_input_fn(inputs, state.attention)
cell_state = state.cell_state
with variable_scope.variable_scope("attention"):
# Step 1: Calculate the true inputs to the cell based on the
# previous attention value.
cell_inputs = self._cell_input_fn(inputs, state.attention)
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
score = self._attention_mechanism(cell_output)
alignments = self._probability_fn(score)
score = self._attention_mechanism(cell_output)
alignments = self._probability_fn(score)
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
alignments = array_ops.expand_dims(alignments, 1)
# Context is the inner product of alignments and values along the
# memory time dimension.
# alignments shape is
# [batch_size, 1, memory_time]
# attention_mechanism.values shape is
# [batch_size, memory_time, attention_mechanism.num_units]
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, attention_mechanism.num_units].
# we then squeeze out the singleton dim.
context = math_ops.matmul(alignments, self._attention_mechanism.values)
context = array_ops.squeeze(context, [1])
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
alignments = array_ops.expand_dims(alignments, 1)
# Context is the inner product of alignments and values along the
# memory time dimension.
# alignments shape is
# [batch_size, 1, memory_time]
# attention_mechanism.values shape is
# [batch_size, memory_time, attention_mechanism.num_units]
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, attention_mechanism.num_units].
# we then squeeze out the singleton dim.
context = math_ops.matmul(alignments, self._attention_mechanism.values)
context = array_ops.squeeze(context, [1])
attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))
attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))
if self._attention_history:
attention_history = state.attention_history.write(
state.time, attention)
else:
attention_history = ()
if self._attention_history:
attention_history = state.attention_history.write(state.time, attention)
else:
attention_history = ()
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
attention_history=attention_history)
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
attention_history=attention_history)
if self._output_attention:
return attention, next_state
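Shape-wise, the step body above reduces to the following NumPy sketch (dimensions invented; the real code runs these ops on TF tensors inside variable_scope("attention")):

import numpy as np

batch, memory_time, units, attn_size = 2, 5, 4, 3
rng = np.random.default_rng(2)
alignments = rng.dirichlet(np.ones(memory_time), size=batch)  # [batch, memory_time]
values = rng.normal(size=(batch, memory_time, units))
cell_output = rng.normal(size=(batch, units))

# Context: inner product of alignments and values over memory_time.
context = np.squeeze(np.matmul(alignments[:, None, :], values), axis=1)

# attention = Dense(attn_size, use_bias=False)(concat([cell_output, context]))
W = rng.normal(size=(2 * units, attn_size))  # stand-in for the layer kernel
attention = np.concatenate([cell_output, context], axis=1) @ W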

View File

@ -269,7 +269,12 @@ class SparseTensor(ItemHandler):
class Image(ItemHandler):
"""An ItemHandler that decodes a parsed Tensor as an image."""
def __init__(self, image_key=None, format_key=None, shape=None, channels=3):
def __init__(self,
image_key=None,
format_key=None,
shape=None,
channels=3,
dtype=dtypes.uint8):
"""Initializes the image.
Args:
@ -282,6 +287,11 @@ class Image(ItemHandler):
accordingly. If left as None, no reshaping is done. A shape should
be supplied only if all the stored images have the same shape.
channels: the number of channels in the image.
dtype: images will be decoded at this bit depth. Different formats
support different bit depths: see tf.image.decode_png and
tf.decode_raw; tf.image.decode_jpeg only supports tf.uint8.
"""
if not image_key:
image_key = 'image/encoded'
@ -293,6 +303,7 @@ class Image(ItemHandler):
self._format_key = format_key
self._shape = shape
self._channels = channels
self._dtype = dtype
def tensors_to_item(self, keys_to_tensors):
"""See base class."""
@ -314,12 +325,17 @@ class Image(ItemHandler):
"""
def decode_png():
return image_ops.decode_png(image_buffer, self._channels)
return image_ops.decode_png(
image_buffer, self._channels, dtype=self._dtype)
def decode_raw():
return parsing_ops.decode_raw(image_buffer, dtypes.uint8)
return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)
def decode_jpg():
if self._dtype != dtypes.uint8:
raise ValueError(
'jpeg decoder can only be used to decode to tf.uint8 but %s was '
'requested for a jpeg image.' % self._dtype)
return image_ops.decode_jpeg(image_buffer, self._channels)
# For RGBA images JPEG is not a valid decoder option.
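A hedged usage sketch of the new `dtype` argument (feature keys and shapes are invented; png and raw honor the requested depth, while jpeg raises for anything but uint8):

from tensorflow.contrib.slim.python.slim.data import tfexample_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import parsing_ops

keys_to_features = {
    'image/encoded': parsing_ops.FixedLenFeature((), dtypes.string),
    'image/format': parsing_ops.FixedLenFeature(
        (), dtypes.string, default_value='png'),
}
items_to_handlers = {
    # Decode 16-bit PNGs; a jpeg-formatted record would raise ValueError.
    'image': tfexample_decoder.Image(channels=1, dtype=dtypes.uint16),
}
decoder = tfexample_decoder.TFExampleDecoder(
    keys_to_features, items_to_handlers)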
@ -401,6 +417,7 @@ class TFExampleDecoder(data_decoder.DataDecoder):
"""
example = parsing_ops.parse_single_example(serialized_example,
self._keys_to_features)
# Reshape non-sparse elements just once:
for k in self._keys_to_features:

View File

@ -224,6 +224,18 @@ class TFExampleDecoderTest(test.TestCase):
self.assertAllClose(image, decoded_image, atol=0)
def testDecodeExampleWithJpegEncodingAt16BitCausesError(self):
image_shape = (2, 3, 3)
unused_image, serialized_example = self.GenerateImage(
image_format='jpeg', image_shape=image_shape)
expected_regex = ('jpeg decoder can only be used to decode to tf.uint8 but '
'.* was requested for a jpeg image.')
with self.assertRaisesRegexp(ValueError, expected_regex):
unused_decoded_image = self.RunDecodeExample(
serialized_example,
tfexample_decoder.Image(dtype=dtypes.uint16),
image_format='jpeg')
def testDecodeExampleWithStringTensor(self):
tensor_shape = (2, 3, 1)
np_array = np.array([[['ab'], ['cd'], ['ef']],

Some files were not shown because too many files have changed in this diff.