Merge commit for internal changes
commit 155332c116
@@ -297,6 +297,7 @@ filegroup(
"//tensorflow/tensorboard/backend:all_files",
"//tensorflow/tensorboard/backend/event_processing:all_files",
"//tensorflow/tensorboard/components:all_files",
"//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
"//tensorflow/tensorboard/components/vz_data_summary:all_files",
"//tensorflow/tensorboard/components/vz_line_chart:all_files",
"//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
@@ -135,6 +135,9 @@ class TF_ManagedBuffer : public TensorBuffer {
proto->set_requested_bytes(rb);
proto->set_allocator_name(tensorflow::cpu_allocator()->Name());
}

// Prevents input forwarding from mutating this buffer.
bool OwnsMemory() const override { return false; }
};

void* allocate_tensor(const char* operation, size_t len) {
@@ -314,6 +314,7 @@ tf_gen_op_wrappers_cc(
name = "cc_ops",
op_lib_names = [
"array_ops",
"audio_ops",
"candidate_sampling_ops",
"control_flow_ops",
"data_flow_ops",
@@ -260,6 +260,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(&xla_allocator);
run_options.set_inter_op_thread_pool(
ctx->device()->tensorflow_cpu_worker_threads()->workers);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
Env* env = Env::Default();
auto start_time = env->NowMicros();
@@ -19,13 +19,6 @@ limitations under the License.

namespace tensorflow {

void XlaDeviceAssignOp::Copy(OpKernelContext* context, Tensor* lhs,
const Tensor& rhs) {
std::shared_ptr<xla::GlobalData> gd =
XlaTransferManager::GetTensorGlobalData(rhs);
XlaTransferManager::SetTensorGlobalData(std::move(gd), lhs);
}

XlaDeviceDummyOp::XlaDeviceDummyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

void XlaDeviceDummyOp::Compute(OpKernelContext* ctx) {
@@ -20,7 +20,6 @@ limitations under the License.

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/kernels/assign_op.h"
#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/kernels/control_flow_ops.h"
#include "tensorflow/core/kernels/identity_op.h"
@@ -30,14 +29,6 @@ limitations under the License.

namespace tensorflow {

// Implementation of Assign for XLA devices.
class XlaDeviceAssignOp : public AssignOp {
public:
using AssignOp::AssignOp;

void Copy(OpKernelContext* context, Tensor* lhs, const Tensor& rhs) override;
};

// Dummy OpKernel, used for kernels assigned to an XLA device that should be
// compiled. Should never be called at runtime since such ops should be
// rewritten to a _XlaLaunch op. If it is called, it means the placer placed an
@@ -72,28 +63,6 @@ class XlaDeviceDummyOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("PlaceholderV2").Device(DEVICE), \
PlaceholderOp); \
\
REGISTER_KERNEL_BUILDER( \
Name("Variable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
VariableOp); \
REGISTER_KERNEL_BUILDER( \
Name("VariableV2").Device(DEVICE).TypeConstraint("dtype", TYPES), \
VariableOp); \
REGISTER_KERNEL_BUILDER( \
Name("TemporaryVariable").Device(DEVICE).TypeConstraint("dtype", TYPES), \
TemporaryVariableOp); \
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
.Device(DEVICE) \
.TypeConstraint("T", TYPES), \
DestroyTemporaryVariableOp); \
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
.Device(DEVICE) \
.TypeConstraint("dtype", TYPES) \
.HostMemory("is_initialized"), \
IsVariableInitializedOp); \
REGISTER_KERNEL_BUILDER( \
Name("Assign").Device(DEVICE).TypeConstraint("T", TYPES), \
XlaDeviceAssignOp); \
\
REGISTER_KERNEL_BUILDER(Name("ControlTrigger").Device(DEVICE), \
ControlTriggerOp); \
REGISTER_KERNEL_BUILDER(Name("Enter").Device(DEVICE), EnterOp); \
@@ -614,9 +614,12 @@ REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("TruncateDiv").TypeConstraint("T", kGpuIntTypes));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("TruncateMod").TypeConstraint("T", kGpuNumericTypes));
REGISTER_XLA_KERNEL(
DEVICE_GPU_XLA_JIT,
Name("TruncatedNormal").TypeConstraint("dtype", kGpuFloatTypes));

// TODO(b/34969189) The implementation of TruncatedNormal triggers a bug on GPU.
// REGISTER_XLA_KERNEL(
// DEVICE_GPU_XLA_JIT,
// Name("TruncatedNormal").TypeConstraint("dtype", kGpuFloatTypes));

REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("Unpack").TypeConstraint("T", kGpuAllTypes));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT, Name("VarIsInitializedOp"));
@@ -301,8 +301,7 @@ StatusOr<std::vector<std::unique_ptr<GlobalData>>> Client::ExecuteParallel(
}

std::vector<std::unique_ptr<GlobalData>> outputs;
for (tensorflow::gtl::ArraySlice<ComputationInstance>::size_type i = 0;
i < computations.size(); ++i) {
for (size_t i = 0; i < computations.size(); ++i) {
outputs.push_back(
MakeUnique<GlobalData>(stub_, response.responses(i).output()));
if (computations[i].execution_profile != nullptr) {
@@ -33,7 +33,7 @@ Computation::Computation(Computation&& computation)
}

void Computation::Reset() {
// TODO(leary) deallocate any owned computation.
// TODO(b/34469253) deallocate any owned computation.
ResetWithoutFreeing();
}
@@ -106,9 +106,7 @@ bool ComputationBuilder::MakeWindow(
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
tensorflow::gtl::ArraySlice<int64> lhs_dilation,
tensorflow::gtl::ArraySlice<int64> rhs_dilation, Window* window) {
const auto verify_size = [&](const tensorflow::gtl::ArraySlice<
int64>::size_type x,
const char* x_name) {
const auto verify_size = [&](const size_t x, const char* x_name) {
if (x == 0 || x == window_dimensions.size()) {
return true;
} else {
@@ -541,8 +541,8 @@ class ComputationBuilder {
// (float32 is specified as there is an implicit float32 -1.0f constant
// exponent).
//
// TODO(leary) axe F32 suffix, can be determined by reflecting on the shape of
// the operand.
// TODO(b/34468990) axe F32 suffix, can be determined by reflecting on the
// shape of the operand.
ComputationDataHandle ReciprocalF32(const ComputationDataHandle& operand);

// Enqueues a negate instruction onto the computation.
@@ -839,7 +839,7 @@ template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
LiteralUtil::PopulateR4FromArray4D(values, layout, literal);
LiteralUtil::PopulateR4FromArray4DWithLayout(values, layout, literal);
});
}
@@ -309,6 +309,14 @@ int LocalClient::default_device_ordinal() const {
return local_service_->backend().default_device_ordinal();
}

const Backend& LocalClient::backend() const {
return local_service_->backend();
}

Backend* LocalClient::mutable_backend() {
return local_service_->mutable_backend();
}

StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
const Computation& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
@@ -224,6 +224,10 @@ class LocalClient : public Client {
// capability).
bool device_ordinal_supported(int device_ordinal) const;

// Returns the backend used to execute computations.
const Backend& backend() const;
Backend* mutable_backend();

private:
LocalService* local_service_;
};
@@ -35,8 +35,7 @@ std::vector<std::pair<int64, int64>> MakePadding(
return low_high_padding;

case Padding::kSame:
for (tensorflow::gtl::ArraySlice<int64>::size_type i = 0;
i < input_dimensions.size(); ++i) {
for (size_t i = 0; i < input_dimensions.size(); ++i) {
int64 input_dimension = input_dimensions[i];
int64 window_dimension = window_dimensions[i];
int64 window_stride = window_strides[i];
@@ -32,8 +32,7 @@ namespace xla {
// Padding and nested layouts not supported yet.
DCHECK_EQ(0, shape.layout().padded_dimensions_size());

for (tensorflow::gtl::ArraySlice<int64>::size_type i = 0;
i < multi_index.size(); ++i) {
for (size_t i = 0; i < multi_index.size(); ++i) {
DCHECK_GE(multi_index[i], 0);
DCHECK_LT(multi_index[i], shape.dimensions(i))
<< "indexing beyond extent in dimension " << i << ":"
@@ -1133,6 +1133,22 @@ cc_library(
],
)

cc_library(
name = "hlo_verifier",
srcs = ["hlo_verifier.cc"],
hdrs = ["hlo_verifier.h"],
deps = [
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)

cc_library(
name = "hlo_rematerialization",
srcs = ["hlo_rematerialization.cc"],
@@ -896,7 +896,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
HloInstruction* zero = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
PaddingConfig no_padding;
for (auto i = 0; i < 2; ++i) {
for (int i = 0; i < 2; ++i) {
auto dimension = no_padding.add_dimensions();
dimension->set_edge_padding_low(0);
dimension->set_edge_padding_high(0);
@@ -926,7 +926,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
PaddingConfig padding;
int64 low_padding[2] = {-1, -2};
int64 high_padding[2] = {2, -3};
for (auto i = 0; i < 2; ++i) {
for (int i = 0; i < 2; ++i) {
auto dimension = padding.add_dimensions();
dimension->set_edge_padding_low(low_padding[i]);
dimension->set_edge_padding_high(high_padding[i]);
@@ -138,8 +138,7 @@ tensorflow::Status AllocationTracker::DeallocateShape(
TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size())
<< "tuple has unexpected number of elements: " << elements.size()
<< " != " << ShapeUtil::TupleElementCount(shape);
for (std::vector<se::DeviceMemoryBase>::size_type i = 0;
i < elements.size(); ++i) {
for (size_t i = 0; i < elements.size(); ++i) {
VLOG(2) << "recursing onto the tuple elements";
TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i],
shape.tuple_shapes(i),
@@ -212,6 +212,13 @@ StatusOr<BufferAllocation::Slice> BufferAssignment::GetUniqueTopLevelSlice(
return GetUniqueSlice(instruction, /*index=*/{});
}

bool BufferAssignment::SharesSliceAtIndex(
const HloInstruction* hlo_a, const ShapeIndex& shape_index_a,
const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const {
return GetUniqueSlice(hlo_a, shape_index_a).ConsumeValueOrDie() ==
GetUniqueSlice(hlo_b, shape_index_b).ConsumeValueOrDie();
}

StatusOr<BufferAllocation::Slice>
BufferAssignment::GetUniqueTopLevelOutputSlice() const {
return GetUniqueTopLevelSlice(
@@ -294,6 +294,15 @@ class BufferAssignment {
return GetPointsToSet(instruction).element(index);
}

// Returns true if 'hlo_a{shape_index_a}' and 'hlo_b{shape_index_b}'
// share the same BufferAllocation::Slice.
// Returns false otherwise.
// REQUIRES: BufferAssignment assigned allocations to both instructions.
bool SharesSliceAtIndex(const HloInstruction* hlo_a,
const ShapeIndex& shape_index_a,
const HloInstruction* hlo_b,
const ShapeIndex& shape_index_b) const;

// Returns the underlying points-to analysis used for this assignment.
const TuplePointsToAnalysis& points_to_analysis() const {
return liveness_->points_to_analysis();
@@ -121,6 +121,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a,
// *) Is element-wise.
// *) Is a loop fusion instruction (with DynamicUpdateSlice fused root) where
// the singleton use of 'a' at 'a.index' is the fused root at operand 0.
// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) {
if (b.instruction()->IsUserOf(alias.instruction()) &&
!CanShareOperandBufferWithUser(alias.instruction(), alias.index(),
@@ -612,6 +612,93 @@ TEST_F(FusedDynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*update_uses_tuple_element1=*/true));
}

class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
protected:
// Builds and runs a computation (see test case computation graphs below).
// Runs BufferLiveness on this computation.
// Returns whether buffer interference is detected between tuple-shaped
// parameter and root instructions at tuple element 1.
bool Run(const bool tuple_element1_has_two_uses) {
auto builder = HloComputation::Builder(TestName());
// Create param0 Tuple.
Shape data_shape = ShapeUtil::MakeShape(F32, {8});
Shape update_shape = ShapeUtil::MakeShape(F32, {3});
auto tuple_param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeTupleShape({data_shape, data_shape}), "param0"));

auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 0));

auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1));

auto update = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR1<float>({2.f, 2.f, 2.f})));

if (tuple_element1_has_two_uses) {
// Add 'gte0' and 'gte1' to create another user of 'gte1'.
gte0 = builder.AddInstruction(HloInstruction::CreateBinary(
data_shape, HloOpcode::kAdd, gte0, gte1));
}
// Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'.
auto starts = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR1<int32>({2})));
auto dynamic_update_slice =
builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
data_shape, gte1, update, starts));
// Create output tuple.
auto tuple_root = builder.AddInstruction(
HloInstruction::CreateTuple({gte0, dynamic_update_slice}));
// Build module and get reference to entry computation.
auto module = MakeUnique<HloModule>(TestName());
module->AddEntryComputation(builder.Build());
// Run BufferLiveness on 'module'.
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Return whether or not buffers interfernce is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
return TupleElementsMayInterfere(*liveness, tuple_param0, tuple_root, {1});
}
};

// Tests that live ranges of buffers Param0[1] and Tuple[1] do not overlap in
// the following computation (because DynamicUpdateSlice (at operand 0) is the
// unique user):
//
// Parameter0
// | |
// GTE(0) GTE(1) Const Const
// | \ | /
// | DynamicUpdateSlice
// \ /
// Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, NoInterference) {
EXPECT_FALSE(Run(/*tuple_element1_has_two_uses=*/false));
}

// Tests that live ranges of buffers Param0[1] and Tuple[1] do overlap because
// GTE(1) has two users:
// 1) DynamicUpdateSlice at operand 0.
// 2) Add at operand 1.
//
// Parameter0
// | |
// GTE(0) GTE(1)
// | / |
// | / |
// Add | Const Const
// | | | |
// | DynamicUpdateSlice
// \ /
// Tuple
//
TEST_F(DynamicUpdateSliceLivenessTest, WithInterference) {
EXPECT_TRUE(Run(/*tuple_element1_has_two_uses=*/true));
}

} // namespace

} // namespace xla
@@ -110,8 +110,7 @@ class Compiler {
// The compiler may optionally specialize to the individual device
// (not just type of device) indicated by the executor.
//
// TODO(leary) will need to update this API when a single computation can run
// across multiple devices simultaneously.
// Use the overload below to compile computations that run in parallel.
virtual StatusOr<std::unique_ptr<Executable>> Compile(
std::unique_ptr<HloModule> module,
std::unique_ptr<HloModuleConfig> module_config, HloDumper dump_hlo,
@@ -69,6 +69,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:inliner",
"//tensorflow/compiler/xla/service:reshape_mover",
"//tensorflow/compiler/xla/service:transpose_folding",
@@ -68,6 +68,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/inliner.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
@@ -214,6 +215,7 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module,
HloDumper dump_hlo) {
// Optimization pipeline.
HloPassPipeline pipeline("CPU", dump_hlo);
pipeline.AddInvariantChecker<HloVerifier>();

// TODO(b/35786417): Re-enable inliner pass after fixing the bug and deciding
// where we will take this pass in future.
@@ -573,8 +575,7 @@ CpuCompiler::CompileAheadOfTime(
}

std::vector<std::unique_ptr<AotCompilationResult>> results;
for (std::vector<std::unique_ptr<HloModule>>::size_type i = 0;
i < hlo_modules.size(); ++i) {
for (size_t i = 0; i < hlo_modules.size(); ++i) {
HloModule* hlo_module = hlo_modules[i].get();
HloModuleConfig* module_config = module_configs[i].get();
@@ -24,8 +24,9 @@ namespace cpu {

class CpuInstructionFusion : public InstructionFusion {
public:
CpuInstructionFusion() {}
~CpuInstructionFusion() override {}
CpuInstructionFusion()
: InstructionFusion(CpuInstructionFusion::IsExpensive) {}
~CpuInstructionFusion() override = default;

protected:
bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override;
@@ -1111,7 +1111,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
llvm_ir::IrArray::Index input_index = reduced_dims_index;
llvm_ir::IrArray::Index::const_iterator it = index.begin();

for (auto i = 0; i < input_index.size(); ++i) {
for (size_t i = 0; i < input_index.size(); ++i) {
if (input_index[i] == nullptr) {
input_index[i] = *it++;
}
@@ -1180,7 +1180,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) {
// output_index := edge_padding_low + operand_index * (interior_padding + 1)
const PaddingConfig& padding_config = pad->padding_config();
llvm_ir::IrArray::Index output_index;
for (auto i = 0; i < operand_index.size(); ++i) {
for (size_t i = 0; i < operand_index.size(); ++i) {
llvm::Value* offset = ir_builder_.CreateMul(
operand_index[i],
ir_builder_.getInt64(padding_config.dimensions(i).interior_padding() +
@@ -1294,12 +1294,12 @@ Status IrEmitter::HandleCustomCall(
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
i8_ptr_type, ir_builder_.getInt32(operands.size()),
"cc_operands_alloca", &ir_builder_);
for (auto i = 0; i < operands.size(); ++i) {
for (size_t i = 0; i < operands.size(); ++i) {
const HloInstruction* operand = operands[i];
llvm::Value* operand_as_i8ptr =
ir_builder_.CreatePointerCast(GetEmittedValueFor(operand), i8_ptr_type);
llvm::Value* slot_in_operands_alloca = ir_builder_.CreateInBoundsGEP(
operands_alloca, {ir_builder_.getInt32(i)});
operands_alloca, {ir_builder_.getInt64(i)});
ir_builder_.CreateStore(operand_as_i8ptr, slot_in_operands_alloca);
}
auto* custom_call_ir_function =
@@ -1659,13 +1659,13 @@ void IrEmitter::EmitArrayFunctionCallInto(
ir_builder_.getInt32(parameter_addresses.size()),
tensorflow::strings::StrCat(name, "_parameter_addresses"),
&ir_builder_);
for (auto i = 0; i < parameter_addresses.size(); ++i) {
for (size_t i = 0; i < parameter_addresses.size(); ++i) {
llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast(
parameter_addresses[i], ir_builder_.getInt8PtrTy(),
llvm_ir::AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i,
"_address_as_i8ptr")));
llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP(
parameter_addresses_buffer, {ir_builder_.getInt32(i)});
parameter_addresses_buffer, {ir_builder_.getInt64(i)});
ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses);
}
@ -97,77 +97,81 @@ static void MarkLiveAddressesInOutput(
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase>
|
||||
ParallelCpuExecutable::ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
se::Stream* stream = run_options->stream();
|
||||
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||
VLOG(3) << "ExecuteOnStream arg size: " << arguments.size();
|
||||
if (!arguments.empty()) {
|
||||
VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque();
|
||||
}
|
||||
|
||||
// Allocate the temporary buffers required for the computation.
|
||||
se::StreamExecutor* stream_executor = stream->parent();
|
||||
int device_ordinal = stream_executor->device_ordinal();
|
||||
int64 buffer_count = assignment_->Allocations().size();
|
||||
VLOG(3) << "temp buffer count: " << buffer_count;
|
||||
|
||||
std::vector<se::DeviceMemoryBase> device_allocations;
|
||||
for (BufferAllocation::Index i = 0; i < buffer_count; ++i) {
|
||||
Status ParallelCpuExecutable::AllocateBuffers(
|
||||
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
|
||||
std::vector<perftools::gputools::DeviceMemoryBase>* buffers) {
|
||||
CHECK_EQ(buffers->size(), assignment_->Allocations().size());
|
||||
VLOG(3) << "Allocating " << assignment_->Allocations().size()
|
||||
<< " allocations for module " << module().name();
|
||||
for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
|
||||
++i) {
|
||||
auto& allocation = assignment_->GetAllocation(i);
|
||||
|
||||
VLOG(3) << allocation.ToString();
|
||||
|
||||
if (allocation.is_entry_computation_parameter()) {
|
||||
// Buffers do not need to be allocated for parameters.
|
||||
device_allocations.push_back(se::DeviceMemoryBase(nullptr));
|
||||
VLOG(3) << "allocation #" << i << " is a parameter";
|
||||
continue;
|
||||
}
|
||||
|
||||
if (allocation.is_thread_local()) {
|
||||
// Buffers do not need to be allocated for thread-local temporaries.
|
||||
device_allocations.push_back(se::DeviceMemoryBase(nullptr));
|
||||
VLOG(3) << "buffer #" << i << " is thread-local";
|
||||
continue;
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
se::DeviceMemoryBase device_allocation,
|
||||
memory_allocator->Allocate(device_ordinal, allocation.size()));
|
||||
int64 buffer_size = allocation.size();
|
||||
if (!(*buffers)[i].is_null()) {
|
||||
VLOG(3) << "buffer #" << i
|
||||
<< " is in the preallocated result ShapedBuffer";
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate(
|
||||
device_ordinal, buffer_size));
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
VLOG(3) << "ParallelCpuExecutable allocating " << allocation.size()
|
||||
<< " bytes for allocation #" << i << " ["
|
||||
<< device_allocation.opaque() << "]";
|
||||
std::vector<string> parts;
|
||||
for (const auto& buffer_offset_size : allocation.assigned_buffers()) {
|
||||
const LogicalBuffer& buffer = *buffer_offset_size.first;
|
||||
parts.push_back(tensorflow::strings::StrCat(
|
||||
buffer.instruction()->parent()->name(), "::", buffer.ToString()));
|
||||
}
|
||||
VLOG(3) << " " << tensorflow::str_util::Join(parts, ", ");
|
||||
VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
|
||||
<< (*buffers)[i].opaque() << "]";
|
||||
}
|
||||
|
||||
device_allocations.push_back(device_allocation);
|
||||
// Since the output buffer and all the temporary buffers were written into
|
||||
// by the JITed code, msan has no way of knowing their memory was
|
||||
// initialized. Mark them initialized so that msan doesn't flag loads from
|
||||
// these buffers.
|
||||
TF_ANNOTATE_MEMORY_IS_INITIALIZED(device_allocation.opaque(),
|
||||
allocation.size());
|
||||
TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
|
||||
assignment_->GetUniqueTopLevelOutputSlice());
|
||||
const BufferAllocation::Index result_index = result_slice.index();
|
||||
VLOG(3) << "result index: " << result_index;
|
||||
VLOG(3) << "result index: " << result_slice.index();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ParallelCpuExecutable::ExecuteComputeFunctions(
|
||||
const ExecutableRunOptions* run_options,
|
||||
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
|
||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
std::vector<se::DeviceMemoryBase> argument_buffers(arguments.size());
|
||||
for (int i = 0; i < arguments.size(); ++i) {
|
||||
TF_RET_CHECK(!ShapeUtil::IsTuple(arguments[i]->shape()));
|
||||
argument_buffers[i] = arguments[i]->buffer(/*index=*/{});
|
||||
}
|
||||
return ExecuteComputeFunctions(run_options, argument_buffers, buffers,
|
||||
hlo_execution_profile);
|
||||
}
|
||||
|
||||
Status ParallelCpuExecutable::ExecuteComputeFunctions(
|
||||
const ExecutableRunOptions* run_options,
|
||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
|
||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
// Allocate profiling counters for each hlo instruction that we would like to
|
||||
// profile. Allocate an additional profile counter for the entire
|
||||
// computation.
|
||||
std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
|
||||
|
||||
std::vector<void*> buffer_pointers;
|
||||
for (auto& device_allocation : device_allocations) {
|
||||
buffer_pointers.reserve(buffers.size());
|
||||
for (auto device_allocation : buffers) {
|
||||
buffer_pointers.push_back(device_allocation.opaque());
|
||||
}
|
||||
|
||||
@ -210,8 +214,7 @@ ParallelCpuExecutable::ExecuteOnStream(
|
||||
|
||||
void** temps_array = buffer_pointers.data();
|
||||
uint64* profile_counters_array = profile_counters.data();
|
||||
auto* thread_pool =
|
||||
CHECK_NOTNULL(run_options->run_options().inter_op_thread_pool());
|
||||
auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool());
|
||||
tensorflow::mutex completion_queue_lock;
|
||||
tensorflow::condition_variable completion_queue_cv;
|
||||
std::deque<HloInstruction*> completion_queue;
|
||||
@ -310,6 +313,42 @@ ParallelCpuExecutable::ExecuteOnStream(
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase>
|
||||
ParallelCpuExecutable::ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
se::Stream* stream = run_options->stream();
|
||||
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||
VLOG(3) << "ExecuteOnStream arg size: " << arguments.size();
|
||||
if (!arguments.empty()) {
|
||||
VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque();
|
||||
}
|
||||
|
||||
// Allocate the temporary buffers required for the computation.
|
||||
se::StreamExecutor* stream_executor = stream->parent();
|
||||
int device_ordinal = stream_executor->device_ordinal();
|
||||
int64 buffer_count = assignment_->Allocations().size();
|
||||
VLOG(3) << "temp buffer count: " << buffer_count;
|
||||
|
||||
std::vector<se::DeviceMemoryBase> device_allocations(
|
||||
assignment_->Allocations().size());
|
||||
TF_RETURN_IF_ERROR(AllocateBuffers(memory_allocator,
|
||||
stream->parent()->device_ordinal(),
|
||||
&device_allocations));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
|
||||
assignment_->GetUniqueTopLevelOutputSlice());
|
||||
const BufferAllocation::Index result_index = result_slice.index();
|
||||
VLOG(3) << "result index: " << result_index;
|
||||
|
||||
TF_RETURN_IF_ERROR(ExecuteComputeFunctions(&run_options->run_options(),
|
||||
arguments, device_allocations,
|
||||
hlo_execution_profile));
|
||||
|
||||
// Mark the buffers that are actually live (used in the output) when the
|
||||
// computation finishes executing.
|
||||
std::unordered_set<const void*> marked_addresses;
|
||||
@ -328,7 +367,7 @@ ParallelCpuExecutable::ExecuteOnStream(
|
||||
// live because they are referenced by the output of the computation
|
||||
// and are needed by the service. They will be deallocated by the
|
||||
// service.
|
||||
for (auto i = 0; i < device_allocations.size(); ++i) {
|
||||
for (size_t i = 0; i < device_allocations.size(); ++i) {
|
||||
auto alloc = device_allocations[i];
|
||||
if (marked_addresses.count(alloc.opaque()) == 0 &&
|
||||
alloc.opaque() != nullptr) {
|
||||
@ -345,8 +384,74 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ParallelCpuExecutable::ExecuteOnStream(
|
||||
const ServiceExecutableRunOptions* run_options,
|
||||
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
|
||||
HloExecutionProfile* hlo_execution_profile) {
|
||||
return Unimplemented(
|
||||
"ParallelCpuExecutable not supported yet with LocalService execution");
|
||||
if (GetRootPointsToSet().IsAmbiguous()) {
|
||||
return Unimplemented("Points-to set of root instruction is ambiguous");
|
||||
}
|
||||
|
||||
se::Stream* stream = run_options->stream();
|
||||
DeviceMemoryAllocator* memory_allocator = run_options->allocator();
|
||||
std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> result_buffer,
|
||||
ShapedBuffer::MakeShapedBuffer(
|
||||
result_shape(), stream->parent()->platform(),
|
||||
stream->parent()->device_ordinal()));
|
||||
|
||||
TF_RETURN_IF_ERROR(AllocateBuffers(
|
||||
memory_allocator, stream->parent()->device_ordinal(), &buffers));
|
||||
|
||||
TF_RETURN_IF_ERROR(ExecuteComputeFunctions(
|
||||
&run_options->run_options(), arguments, buffers, hlo_execution_profile));
|
||||
|
||||
// Copy DeviceMemoryBase values which contain the array(s) of the result into
|
||||
// the respective location in ShapedBuffer which is returned to the caller.
|
||||
std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
|
||||
TF_RETURN_IF_ERROR(
|
||||
result_buffer->mutable_shape_index_to_buffer_entry()
|
||||
->ForEachMutableElement(
|
||||
[&buffers, &buffers_in_result, &result_buffer, this](
|
||||
const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) {
|
||||
if (is_leaf) {
|
||||
const std::vector<const LogicalBuffer*>& sources =
|
||||
this->GetRootPointsToSet().element(index);
|
||||
// The points to set is unambiguous so the set should be a
|
||||
// singleton.
|
||||
CHECK_EQ(1, sources.size());
|
||||
const LogicalBuffer* buffer_source = sources[0];
|
||||
HloInstruction* src = buffer_source->instruction();
|
||||
|
||||
// The source for this result buffer can be a nested buffer
|
||||
// such as a tuple element.
|
||||
|
||||
// The source instruction should have a non-parameter buffer
|
||||
// assigned.
|
||||
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
|
||||
this->assignment_->GetUniqueSlice(
|
||||
src, buffer_source->index()));
|
||||
CHECK(!slice.allocation()->is_entry_computation_parameter());
|
||||
|
||||
const BufferAllocation::Index buffer_index = slice.index();
|
||||
const se::DeviceMemoryBase& buffer = buffers[buffer_index];
|
||||
CHECK(!buffer.is_null() || buffer.size() == 0);
|
||||
*buffer_entry = result_buffer->mutable_buffers()->size();
|
||||
result_buffer->mutable_buffers()->push_back(buffer);
|
||||
buffers_in_result[buffer_index] = true;
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
// Free all buffers not in the result.
|
||||
for (size_t i = 0; i < buffers.size(); ++i) {
|
||||
se::DeviceMemoryBase alloc = buffers[i];
|
||||
if (!buffers_in_result[i] && !alloc.is_null()) {
|
||||
VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
|
||||
<< alloc.opaque() << "]";
|
||||
TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
|
||||
stream->parent()->device_ordinal(), &alloc));
|
||||
}
|
||||
}
|
||||
|
||||
return std::move(result_buffer);
|
||||
}
|
||||
|
||||
StatusOr<perftools::gputools::DeviceMemoryBase>
|
||||
@ -358,5 +463,10 @@ ParallelCpuExecutable::ExecuteAsyncOnStream(
|
||||
"Asynchronous execution on stream is not yet supported on CPU.");
|
||||
}
|
||||
|
||||
const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const {
|
||||
return assignment_->points_to_analysis().GetPointsToSet(
|
||||
module().entry_computation()->root_instruction());
|
||||
}
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
@ -84,6 +84,35 @@ class ParallelCpuExecutable : public Executable {
|
||||
}
|
||||
|
||||
private:
|
||||
// Allocate buffers required for execution and assign them to the elements of
|
||||
// "buffers". "buffers" should be sized to the number of buffers in buffer
|
||||
// assignment. Each vector element corresponds to a particular Index. If
|
||||
// a vector element already contains a non-null DeviceMemoryBase, then no
|
||||
// buffer is assigned for this element.
|
||||
Status AllocateBuffers(
|
||||
DeviceMemoryAllocator* memory_allocator, int device_ordinal,
|
||||
std::vector<perftools::gputools::DeviceMemoryBase>* buffers);
|
||||
|
||||
// Calls the generated functions in 'function_names_', performing the
|
||||
// computation with the given arguments using the supplied buffers.
|
||||
Status ExecuteComputeFunctions(
|
||||
const ExecutableRunOptions* run_options,
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
arguments,
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
buffers,
|
||||
HloExecutionProfile* hlo_execution_profile);
|
||||
Status ExecuteComputeFunctions(
|
||||
const ExecutableRunOptions* run_options,
|
||||
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
|
||||
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
buffers,
|
||||
HloExecutionProfile* hlo_execution_profile);
|
||||
|
||||
// Returns the points-to set of the root instruction of the entry
|
||||
// computation. Uses points-to analysis from buffer assignment.
|
||||
const PointsToSet& GetRootPointsToSet() const;
|
||||
|
||||
// The JIT containing compiled modules.
|
||||
tensorflow::mutex jit_mutex_;
|
||||
std::unique_ptr<SimpleOrcJIT> jit_ GUARDED_BY(jit_mutex_);
|
||||
|
@ -937,6 +937,68 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
};
|
||||
case HloOpcode::kRng:
|
||||
return MakeRngElementGenerator(hlo, operand_to_generator);
|
||||
case HloOpcode::kPad:
|
||||
return [=, &operand_to_generator](
|
||||
const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
|
||||
auto index = padded_index;
|
||||
llvm::Value* in_bounds = ir_builder_->getTrue();
|
||||
for (size_t i = 0; i < index.size(); ++i) {
|
||||
auto index_typed_const = [=](int64 n) {
|
||||
return llvm::ConstantInt::get(index[i]->getType(), n);
|
||||
};
|
||||
const auto& pad_dim = hlo->padding_config().dimensions(i);
|
||||
index[i] = ir_builder_->CreateSub(
|
||||
index[i], index_typed_const(pad_dim.edge_padding_low()));
|
||||
in_bounds = ir_builder_->CreateAnd(
|
||||
in_bounds,
|
||||
ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
|
||||
"in_bounds");
|
||||
in_bounds = ir_builder_->CreateAnd(
|
||||
in_bounds,
|
||||
ir_builder_->CreateICmpEQ(
|
||||
index_typed_const(0),
|
||||
ir_builder_->CreateURem(
|
||||
index[i],
|
||||
index_typed_const(pad_dim.interior_padding() + 1))),
|
||||
"in_bounds");
|
||||
index[i] = ir_builder_->CreateSDiv(
|
||||
index[i], index_typed_const(pad_dim.interior_padding() + 1));
|
||||
in_bounds = ir_builder_->CreateAnd(
|
||||
in_bounds,
|
||||
ir_builder_->CreateICmpSLT(
|
||||
index[i],
|
||||
index_typed_const(hlo->operand(0)->shape().dimensions(i))),
|
||||
"in_bounds");
|
||||
}
|
||||
|
||||
// if (in_bounds) {
|
||||
// ret_value = operand0[index]; // source
|
||||
// } else {
|
||||
// ret_value = *operand1; // padding
|
||||
// }
|
||||
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||
ir_builder_),
|
||||
"pad_result_addr", ir_builder_);
|
||||
llvm_ir::LlvmIfData if_data =
|
||||
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
|
||||
SetToFirstInsertPoint(if_data.true_block, ir_builder_);
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
|
||||
operand_to_generator.at(hlo->operand(0))(index));
|
||||
ir_builder_->CreateStore(operand_value, ret_value_addr);
|
||||
|
||||
SetToFirstInsertPoint(if_data.false_block, ir_builder_);
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
|
||||
operand_to_generator.at(hlo->operand(1))({}));
|
||||
ir_builder_->CreateStore(padding_value, ret_value_addr);
|
||||
|
||||
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
|
||||
// Don't create phi(operand_value, padding_value) here, because invoking
|
||||
// operand_to_generator may create new basic blocks, making the parent
|
||||
// of operand_value or padding_value no longer a predecessor of
|
||||
// if_data.after_block.
|
||||
return ir_builder_->CreateLoad(ret_value_addr);
|
||||
};
|
||||
default:
|
||||
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
|
||||
return Unimplemented("%s", HloOpcodeString(hlo->opcode()).c_str());
|
||||
|
@ -40,8 +40,7 @@ Executable::ExecuteOnStreams(
|
||||
|
||||
std::vector<perftools::gputools::DeviceMemoryBase> return_values(
|
||||
run_options.size());
|
||||
for (tensorflow::gtl::ArraySlice<const ExecutableRunOptions>::size_type i = 0;
|
||||
i < run_options.size(); ++i) {
|
||||
for (size_t i = 0; i < run_options.size(); ++i) {
|
||||
// We cannot BlockHostUntilDone() on the already-launched executions in case
|
||||
// of error, since if the executions communicate, the initially launched
|
||||
// executions may never complete if not all executions are running.
|
||||
|
@ -39,9 +39,6 @@ namespace xla {
|
||||
|
||||
// A given platform's compiler will produce an Executable -- this is a uniform
|
||||
// interface that is used for launching compiled programs across platforms.
|
||||
//
|
||||
// TODO(leary) will need to extend this to support multiple streams/devices as
|
||||
// we begin to compile single programs to run on multiple devices.
|
||||
class Executable {
|
||||
public:
|
||||
explicit Executable(std::unique_ptr<HloModule> hlo_module,
|
||||
|
@ -118,7 +118,7 @@ GenericTransferManager::ShallowCopyTupleFromDevice(
|
||||
|
||||
// Create a DeviceMemoryBase from each void* pointer.
|
||||
std::vector<se::DeviceMemoryBase> destination;
|
||||
for (std::vector<void*>::size_type i = 0; i < element_pointers.size(); ++i) {
|
||||
for (size_t i = 0; i < element_pointers.size(); ++i) {
|
||||
if (element_pointers[i] == nullptr &&
|
||||
!ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) {
|
||||
return FailedPrecondition("tuple contains nullptr at element %lu", i);
|
||||
|
@ -356,13 +356,12 @@ cc_library(
|
||||
srcs = ["fusion_merger.cc"],
|
||||
hdrs = ["fusion_merger.h"],
|
||||
deps = [
|
||||
":instruction_fusion",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:instruction_fusion",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
@ -434,6 +433,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
|
||||
"//tensorflow/compiler/xla/service:hlo_verifier",
|
||||
"//tensorflow/compiler/xla/service:reshape_mover",
|
||||
"//tensorflow/compiler/xla/service:transpose_folding",
|
||||
"//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend",
|
||||
|
@ -87,7 +87,8 @@ tensorflow::Status BufferAllocations::TearDown(
|
||||
const std::set<se::DeviceMemoryBase>& live_addresses,
|
||||
const BufferAssignment& buffer_assignment) {
|
||||
// Deallocate temporary buffers.
|
||||
for (auto i = 0; i < buffer_assignment.Allocations().size(); ++i) {
|
||||
const int64 num_buffers = buffer_assignment.Allocations().size();
|
||||
for (BufferAllocation::Index i = 0; i < num_buffers; ++i) {
|
||||
const BufferAllocation& allocation = buffer_assignment.GetAllocation(i);
|
||||
se::DeviceMemoryBase buffer_address = GetDeviceAddress(allocation.index());
|
||||
// Deallocate buffers marked "maybe_live_out" but aren't actually live out,
|
||||
|
@ -270,69 +270,6 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
|
||||
const HloInstruction* hlo,
|
||||
const HloToElementGeneratorMap& operand_to_generator) const {
|
||||
switch (hlo->opcode()) {
|
||||
case HloOpcode::kPad:
|
||||
return [=, &operand_to_generator](
|
||||
const IrArray::Index& padded_index) -> StatusOr<llvm::Value*> {
|
||||
auto index = padded_index;
|
||||
llvm::Value* in_bounds =
|
||||
llvm::ConstantInt::get(ir_builder_->getInt1Ty(), 1);
|
||||
for (auto i = 0; i < index.size(); ++i) {
|
||||
auto index_typed_const = [=](int64 n) {
|
||||
return llvm::ConstantInt::get(index[i]->getType(), n);
|
||||
};
|
||||
const auto& pad_dim = hlo->padding_config().dimensions(i);
|
||||
index[i] = ir_builder_->CreateSub(
|
||||
index[i], index_typed_const(pad_dim.edge_padding_low()));
|
||||
in_bounds = ir_builder_->CreateAnd(
|
||||
in_bounds,
|
||||
ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
|
||||
"in_bounds");
|
||||
in_bounds = ir_builder_->CreateAnd(
|
||||
in_bounds,
|
||||
ir_builder_->CreateICmpEQ(
|
||||
index_typed_const(0),
|
||||
ir_builder_->CreateURem(
|
||||
index[i],
|
||||
index_typed_const(pad_dim.interior_padding() + 1))),
|
||||
"in_bounds");
|
||||
index[i] = ir_builder_->CreateSDiv(
|
||||
index[i], index_typed_const(pad_dim.interior_padding() + 1));
|
||||
in_bounds = ir_builder_->CreateAnd(
|
||||
in_bounds,
|
||||
ir_builder_->CreateICmpSLT(
|
||||
index[i],
|
||||
index_typed_const(hlo->operand(0)->shape().dimensions(i))),
|
||||
"in_bounds");
|
||||
}
|
||||
|
||||
// if (in_bounds) {
|
||||
// ret_value = operand0[index]; // source
|
||||
// } else {
|
||||
// ret_value = *operand1; // padding
|
||||
// }
|
||||
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
|
||||
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(),
|
||||
ir_builder_),
|
||||
"pad_result_addr", ir_builder_);
|
||||
llvm_ir::LlvmIfData if_data =
|
||||
llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
|
||||
SetToFirstInsertPoint(if_data.true_block, ir_builder_);
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
|
||||
operand_to_generator.at(hlo->operand(0))(index));
|
||||
ir_builder_->CreateStore(operand_value, ret_value_addr);
|
||||
|
||||
SetToFirstInsertPoint(if_data.false_block, ir_builder_);
|
||||
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
|
||||
operand_to_generator.at(hlo->operand(1))({}));
|
||||
ir_builder_->CreateStore(padding_value, ret_value_addr);
|
||||
|
||||
SetToFirstInsertPoint(if_data.after_block, ir_builder_);
|
||||
// Don't create phi(operand_value, padding_value) here, because invoking
|
||||
// operand_to_generator may create new basic blocks, making the parent
|
||||
// of operand_value or padding_value no longer a predecessor of
|
||||
// if_data.after_block.
|
||||
return ir_builder_->CreateLoad(ret_value_addr);
|
||||
};
|
||||
case HloOpcode::kMap:
|
||||
return [=, &operand_to_generator](
|
||||
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
|
||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -221,7 +221,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) {
|
||||
fusion->fused_instructions().end(),
|
||||
[](const std::unique_ptr<HloInstruction>& instruction) {
|
||||
if (instruction->opcode() != HloOpcode::kParameter &&
|
||||
IsExpensive(*instruction)) {
|
||||
GpuInstructionFusion::IsExpensive(*instruction)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
|
@ -51,6 +51,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/reshape_mover.h"
|
||||
#include "tensorflow/compiler/xla/service/transpose_folding.h"
|
||||
@ -121,6 +122,7 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
|
||||
const se::DeviceDescription& device_desc) {
|
||||
{
|
||||
HloPassPipeline pipeline("optimization", dump_hlo);
|
||||
pipeline.AddInvariantChecker<HloVerifier>();
|
||||
{
|
||||
auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>(
|
||||
"simplification", dump_hlo);
|
||||
@ -157,6 +159,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
|
||||
// (b/27180329). Therefore, in that case, we set the output to be a copy of
|
||||
// the parameter.
|
||||
HloPassPipeline pipeline("GPU-ir-emit-prepare", dump_hlo);
|
||||
pipeline.AddInvariantChecker<HloVerifier>();
|
||||
pipeline.AddPass<PadInsertion>();
|
||||
pipeline.AddPass<GpuLayoutAssignment>(
|
||||
module_config->mutable_entry_computation_layout());
|
||||
|
@ -25,7 +25,7 @@ namespace gpu {
|
||||
class GpuInstructionFusion : public InstructionFusion {
|
||||
public:
|
||||
explicit GpuInstructionFusion(bool may_duplicate)
|
||||
: InstructionFusion(may_duplicate) {}
|
||||
: InstructionFusion(GpuInstructionFusion::IsExpensive, may_duplicate) {}
|
||||
|
||||
bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override;
|
||||
|
||||
|
@ -513,7 +513,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg,
|
||||
llvm_ir::IrArray::Index input_index = reduced_dims_index;
|
||||
llvm_ir::IrArray::Index::const_iterator it = index.begin();
|
||||
|
||||
for (auto i = 0; i < input_index.size(); ++i) {
|
||||
for (size_t i = 0; i < input_index.size(); ++i) {
|
||||
if (input_index[i] == nullptr) {
|
||||
input_index[i] = *it++;
|
||||
}
|
||||
@ -614,7 +614,7 @@ llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest(
|
||||
llvm_ir::IrArray::Index index =
|
||||
loop_nest->AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
|
||||
// Verify every dimension except the reduction dimension was set in the index.
|
||||
for (auto dimension = 0; dimension < index.size(); ++dimension) {
|
||||
for (size_t dimension = 0; dimension < index.size(); ++dimension) {
|
||||
if (dimension == reduction_dimension) {
|
||||
DCHECK_EQ(nullptr, index[dimension]);
|
||||
} else {
|
||||
|
@ -283,14 +283,7 @@ bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment,
|
||||
return false;
|
||||
}
|
||||
auto* operand = fusion->operand(fusion_operand->parameter_number());
|
||||
|
||||
BufferAllocation::Slice operand_slice =
|
||||
assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie();
|
||||
|
||||
BufferAllocation::Slice fusion_slice =
|
||||
assignment.GetUniqueTopLevelSlice(fusion).ConsumeValueOrDie();
|
||||
|
||||
return operand_slice == fusion_slice;
|
||||
return assignment.SharesSliceAtIndex(fusion, {}, operand, index);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -387,9 +380,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
|
||||
TF_RETURN_IF_ERROR(root->Accept(&fused_emitter));
|
||||
|
||||
// Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0.
|
||||
ShapeIndex index_unused;
|
||||
auto* fusion_operand =
|
||||
LatestNonGteAncestorAndIndex(root->operand(0), &index_unused);
|
||||
auto* fusion_operand = LatestNonGteAncestor(root->operand(0));
|
||||
CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode());
|
||||
|
||||
// Operand(0) the input array which shares an allocation with the output.
|
||||
|
@ -79,7 +79,7 @@ ThunkSchedule::ThunkSchedule(
|
||||
|
||||
void ThunkSchedule::RemoveRedundantDependencyEdges() {
|
||||
std::unordered_map<const Thunk*, int> thunk_to_total_order;
|
||||
for (auto i = 0; i < thunk_total_order_.size(); ++i) {
|
||||
for (int i = 0; i < thunk_total_order_.size(); ++i) {
|
||||
InsertOrDie(&thunk_to_total_order, thunk_total_order_[i], i);
|
||||
}
|
||||
|
||||
|
@ -92,13 +92,22 @@ HloInstruction* HloComputation::AddInstructionInternal(
|
||||
// Generate a unique name for the instruction.
|
||||
instruction->set_name(
|
||||
instruction_name_uniquer_.GetUniqueName(instruction->name()));
|
||||
instruction->set_parent(this);
|
||||
Reparent(instruction.get());
|
||||
HloInstruction* pinst = instruction.get();
|
||||
instruction_iterators_[pinst] =
|
||||
instructions_.insert(instructions_.end(), std::move(instruction));
|
||||
return pinst;
|
||||
}
|
||||
|
||||
void HloComputation::Reparent(HloInstruction* instruction) {
|
||||
instruction->set_parent(this);
|
||||
if (instruction->opcode() == HloOpcode::kFusion) {
|
||||
for (auto& i : instruction->fused_instructions()) {
|
||||
Reparent(i.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* static */ bool HloComputation::IsRemovable(const HloOpcode& opcode) {
|
||||
return !(opcode == HloOpcode::kParameter || opcode == HloOpcode::kRecv ||
|
||||
opcode == HloOpcode::kSend || opcode == HloOpcode::kTrace ||
|
||||
|
@ -235,6 +235,14 @@ class HloComputation {
|
||||
HloInstruction* AddInstructionInternal(
|
||||
std::unique_ptr<HloInstruction> instruction);
|
||||
|
||||
// Helper for setting the parent of instructions that are added to this
|
||||
// computation.
|
||||
//
|
||||
// Because we clone HLO instructions without knowing what computation they're
|
||||
// destined to be added to, this is required to appropriate set the parent on
|
||||
// fused instruction sequences.
|
||||
void Reparent(HloInstruction* instruction);
|
||||
|
||||
// Fuses HLOs in instructions_to_fuse into fusion_instruction.
|
||||
//
|
||||
// Pre-condition: fusion_instruction's opcode is kFusion.
|
||||
|
@ -75,8 +75,7 @@ string InstructionSequenceGraph(
|
||||
std::vector<HloInstruction*> param_instructions;
|
||||
for (auto& instruction : instructions) {
|
||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||
std::vector<HloInstruction*>::size_type param_number =
|
||||
instruction->parameter_number();
|
||||
size_t param_number = instruction->parameter_number();
|
||||
|
||||
if (param_instructions.size() < param_number + 1) {
|
||||
param_instructions.resize(param_number + 1, nullptr);
|
||||
|
@ -431,6 +431,7 @@ HloInstruction::CreateSelectAndScatter(
|
||||
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
|
||||
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
|
||||
instruction->fusion_kind_ = fusion_kind;
|
||||
instruction->set_parent(fused_root->parent());
|
||||
instruction->CloneAndFuseInternal(fused_root);
|
||||
instruction->CheckFusionInstruction();
|
||||
return instruction;
|
||||
@ -568,6 +569,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
|
||||
std::unique_ptr<HloInstruction> param_instruction =
|
||||
CreateParameter(param_no, operand->shape(), "fusion_param");
|
||||
|
||||
param_instruction->set_parent(parent());
|
||||
param_instruction->parent_fusion_instruction_ = this;
|
||||
fused_parameters_.push_back(param_instruction.get());
|
||||
fused_instructions_.push_back(std::move(param_instruction));
|
||||
@ -602,6 +604,7 @@ void HloInstruction::CheckFusionInstruction() const {
|
||||
for (auto& instruction : fused_instructions_) {
|
||||
CHECK(instruction->IsFused());
|
||||
CHECK_EQ(this, instruction->fusion_instruction());
|
||||
CHECK_EQ(parent(), instruction->parent()) << instruction->ToString();
|
||||
}
|
||||
|
||||
// Fused root instruction and fused parameters must all be owned by the fusion
|
||||
@ -838,12 +841,14 @@ std::unique_ptr<HloInstruction> HloInstruction::Clone(const string& suffix) {
|
||||
std::unique_ptr<HloInstruction> clone =
|
||||
CloneWithNewOperands(shape_, operands_);
|
||||
clone->name_ = name() + "." + suffix;
|
||||
clone->set_parent(parent());
|
||||
return clone;
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
|
||||
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
|
||||
CHECK_EQ(opcode_, HloOpcode::kFusion);
|
||||
CHECK(parent() != nullptr);
|
||||
|
||||
auto new_instruction =
|
||||
WrapUnique(new HloInstruction(HloOpcode::kFusion, shape));
|
||||
@ -883,6 +888,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
|
||||
old_fused_instruction->CloneWithNewOperands(
|
||||
old_fused_instruction->shape(), new_operands));
|
||||
HloInstruction* new_fused_instruction = new_fused_instructions.back().get();
|
||||
new_fused_instruction->set_parent(parent());
|
||||
new_fused_instruction->parent_fusion_instruction_ = new_instruction.get();
|
||||
InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction);
|
||||
}
|
||||
@ -893,6 +899,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneFusionWithNewOperands(
|
||||
new_instruction->fused_instructions_ = std::move(new_fused_instructions);
|
||||
new_instruction->fused_parameters_ = std::move(new_fused_parameters);
|
||||
new_instruction->fused_root_ = FindOrDie(old_to_new, fused_root_);
|
||||
new_instruction->set_parent(parent());
|
||||
new_instruction->CheckFusionInstruction();
|
||||
return new_instruction;
|
||||
}
|
||||
|
@ -538,6 +538,9 @@ class HloInstruction {
|
||||
// instruction. The order is a reverse postorder of the fused expression (root
|
||||
// is first in the order).
|
||||
//
|
||||
// Note: although the list itself is const, the instructions contained in the
|
||||
// list returned here are mutable.
|
||||
//
|
||||
// Precondition: opcode() == HloOpcode::kFusion
|
||||
const std::list<std::unique_ptr<HloInstruction>>& fused_instructions() const;
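A minimal usage sketch of this accessor, assuming a pointer named fusion whose opcode is kFusion; per the reverse-postorder guarantee above, the fused root is visited first.

// Illustrative only: 'fusion' is an assumed HloInstruction* with opcode kFusion.
for (const std::unique_ptr<HloInstruction>& fused :
     fusion->fused_instructions()) {
  VLOG(2) << "fused instruction: " << fused->ToString();
}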
|
||||
|
||||
|
@ -26,6 +26,8 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
using ::tensorflow::strings::StrAppend;
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
@ -44,6 +46,14 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
tensorflow::str_util::Split(flags->xla_disable_hlo_passes, ',');
|
||||
tensorflow::gtl::FlatSet<string> disabled_passes(tmp.begin(), tmp.end());
|
||||
|
||||
auto run_invariant_checkers = [this, module]() -> Status {
|
||||
for (auto& invariant_checker : invariant_checkers_) {
|
||||
TF_ASSIGN_OR_RETURN(bool changed, invariant_checker->Run(module));
|
||||
TF_RET_CHECK(!changed) << "invariant checkers must not change the graph";
|
||||
}
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
string prefix = name().ToString() + ": pipeline start";
|
||||
bool changed = false;
|
||||
string message;
|
||||
@ -55,15 +65,17 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
|
||||
|
||||
// Emit label containing: "after foo-pass, before bar-pass".
|
||||
message.clear();
|
||||
tensorflow::strings::StrAppend(&message, prefix, ", before ", pass->name());
|
||||
StrAppend(&message, prefix, ", before ", pass->name());
|
||||
DumpModule(dumper_, *module, message);
|
||||
|
||||
TF_RETURN_IF_ERROR(run_invariant_checkers());
|
||||
TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
|
||||
|
||||
changed |= changed_this_pass;
|
||||
prefix.clear();
|
||||
tensorflow::strings::StrAppend(&prefix, name(), ": after ", pass->name());
|
||||
StrAppend(&prefix, name(), ": after ", pass->name());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(run_invariant_checkers());
|
||||
DumpModule(dumper_, *module, prefix + ", pipeline end");
|
||||
return changed;
|
||||
}
|
||||
|
@ -52,6 +52,16 @@ class HloPassPipeline : public HloPassInterface {
|
||||
return *pass;
|
||||
}
|
||||
|
||||
// Add an invariant-checking pass to the pipeline. It will be run before and
|
||||
// after each HLO pass. The invariant checking pass must not mutate the graph
|
||||
// (it is required to always return "false" from its Run() method).
|
||||
template <typename T, typename... Args>
|
||||
T& AddInvariantChecker(Args&&... args) {
|
||||
auto pass = new T(std::forward<Args>(args)...);
|
||||
invariant_checkers_.push_back(std::unique_ptr<T>(pass));
|
||||
return *pass;
|
||||
}
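A hedged usage sketch of this hook: the pipeline constructor arguments, the AddPass helper invocation, and SomeOptimizationPass are assumptions for illustration, while HloVerifier is the checker introduced in this change.

// Illustrative only; names marked hypothetical are not part of this diff.
HloPassPipeline pipeline("example-pipeline", dumper);
pipeline.AddInvariantChecker<HloVerifier>();  // run before and after each pass
pipeline.AddPass<SomeOptimizationPass>();     // hypothetical HLO pass
TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));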
|
||||
|
||||
// Run all passes on the given HLO module.
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
@ -59,6 +69,7 @@ class HloPassPipeline : public HloPassInterface {
|
||||
const string name_;
|
||||
Compiler::HloDumper dumper_;
|
||||
std::vector<std::unique_ptr<HloPassInterface>> passes_;
|
||||
std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
|
||||
};
|
||||
|
38
tensorflow/compiler/xla/service/hlo_verifier.cc
Normal file
@ -0,0 +1,38 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<bool> HloVerifier::Run(HloModule* module) {
|
||||
for (auto& computation : module->computations()) {
|
||||
for (const auto& instruction : computation->instructions()) {
|
||||
TF_RET_CHECK(instruction->parent() == computation.get());
|
||||
if (instruction->opcode() == HloOpcode::kFusion) {
|
||||
for (const auto& fused : instruction->fused_instructions()) {
|
||||
TF_RET_CHECK(fused->parent() == computation.get())
|
||||
<< "Fused HLO was missing a parent: " << fused->ToString()
|
||||
<< " parent: " << fused->parent()
|
||||
<< " computation: " << computation.get();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace xla
|
37
tensorflow/compiler/xla/service/hlo_verifier.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// HLO pass that verifies invariants of HLO instructions for each computation in
|
||||
// the module.
|
||||
class HloVerifier : public HloPassInterface {
|
||||
public:
|
||||
~HloVerifier() override = default;
|
||||
tensorflow::StringPiece name() const override { return "verifier"; }
|
||||
|
||||
// Note: always returns false (no instructions are ever modified by this
|
||||
// pass).
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_
|
@ -29,7 +29,8 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
bool IsExpensive(const HloInstruction& instruction) {
|
||||
/*static*/ bool InstructionFusion::IsExpensive(
|
||||
const HloInstruction& instruction) {
|
||||
switch (instruction.opcode()) {
|
||||
// Cheap instructions.
|
||||
case HloOpcode::kAbs:
|
||||
@ -105,9 +106,14 @@ bool IsExpensive(const HloInstruction& instruction) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer) {
|
||||
return !(producer->users().size() == 1 && consumer->IsUserOf(producer));
|
||||
namespace {
|
||||
// Returns true if fusing producer into consumer would cause producer to be
|
||||
// duplicated. This is the case if producer has uses other than consumer.
|
||||
bool FusionWouldDuplicate(const HloInstruction& producer,
|
||||
const HloInstruction& consumer) {
|
||||
return !(producer.users().size() == 1 && consumer.IsUserOf(&producer));
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
||||
bool changed = false;
|
||||
@ -125,8 +131,7 @@ StatusOr<bool> InstructionFusion::Run(HloModule* module) {
|
||||
std::vector<HloInstruction*> post_order(post_order_list.begin(),
|
||||
post_order_list.end());
|
||||
tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index;
|
||||
for (std::vector<HloInstruction*>::size_type i = 0; i < post_order.size();
|
||||
++i) {
|
||||
for (size_t i = 0; i < post_order.size(); ++i) {
|
||||
InsertOrDie(&post_order_index, post_order[i], i);
|
||||
}
|
||||
|
||||
@ -263,8 +268,8 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
int64 operand_index) {
|
||||
HloInstruction* producer = consumer->mutable_operand(operand_index);
|
||||
// Cost condition: don't duplicate expensive instructions.
|
||||
if (FusionWouldDuplicate(producer, consumer) &&
|
||||
(IsExpensive(*producer) || !may_duplicate_)) {
|
||||
if (FusionWouldDuplicate(*producer, *consumer) &&
|
||||
(is_expensive_(*producer) || !may_duplicate_)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -277,7 +282,7 @@ bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
|
||||
// Cost condition: not fuse (expensive producers) and (consumers who reuse
|
||||
// operand elements).
|
||||
if (consumer->ReusesOperandElements(operand_index) &&
|
||||
IsExpensive(*producer)) {
|
||||
is_expensive_(*producer)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -24,15 +24,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Returns true if the computation of the given instruction is significantly
|
||||
// more expensive than just writing all the values of the instructions' result
|
||||
// array. Expensive operations should not be duplicated.
|
||||
bool IsExpensive(const HloInstruction& instruction);
|
||||
|
||||
// Returns true if fusing producer into consumer would cause producer to be
|
||||
// duplicated. This is the case if producer has uses other than consumer.
|
||||
bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
|
||||
|
||||
// HLO pass which performs instruction fusion. Instructions are fused
|
||||
// "vertically", meaning producing instructions are fused into their consumers
|
||||
// with the intent that the loops which compute their values will be fused in
|
||||
@ -40,15 +31,22 @@ bool FusionWouldDuplicate(HloInstruction* producer, HloInstruction* consumer);
|
||||
// instructions to fuse.
|
||||
class InstructionFusion : public HloPassInterface {
|
||||
public:
|
||||
explicit InstructionFusion(bool may_duplicate = true)
|
||||
: may_duplicate_(may_duplicate) {}
|
||||
~InstructionFusion() override {}
|
||||
explicit InstructionFusion(
|
||||
std::function<bool(const HloInstruction& instruction)> is_expensive,
|
||||
bool may_duplicate = true)
|
||||
: is_expensive_(is_expensive), may_duplicate_(may_duplicate) {}
|
||||
~InstructionFusion() override = default;
|
||||
tensorflow::StringPiece name() const override { return "fusion"; }
|
||||
|
||||
// Run instruction fusion on the given computation. Returns whether the
|
||||
// computation was changed (instructions were fused).
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
// Returns true if the computation of the given instruction is significantly
|
||||
// more expensive than just writing all the values of the instruction's result
|
||||
// array. Expensive operations will not be duplicated.
|
||||
static bool IsExpensive(const HloInstruction& instruction);
|
||||
|
||||
protected:
|
||||
// Returns whether the given producer instruction should be fused into the
|
||||
// given consumer instruction. producer is necessarily an operand of consumer.
|
||||
@ -74,6 +72,10 @@ class InstructionFusion : public HloPassInterface {
|
||||
private:
|
||||
HloInstruction* Fuse(HloInstruction* producer, HloInstruction* consumer);
|
||||
|
||||
// Used to determine if an HLO is expensive. Expensive operations will not be
|
||||
// duplicated.
|
||||
std::function<bool(const HloInstruction& instruction)> is_expensive_;
|
||||
|
||||
// Returns whether we may duplicate an instruction if we want to fuse it.
|
||||
bool may_duplicate_;
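A hedged sketch of supplying a backend-specific cost model through the new constructor; the divide-is-expensive policy and the module variable are assumptions for illustration, mirroring the test usage below.

// Illustrative only: a custom cost model that also treats divides as
// expensive, with duplication disabled.
InstructionFusion fusion(
    [](const HloInstruction& hlo) {
      return hlo.opcode() == HloOpcode::kDivide ||
             InstructionFusion::IsExpensive(hlo);
    },
    /*may_duplicate=*/false);
bool changed = fusion.Run(module).ValueOrDie();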
|
||||
|
||||
|
@ -36,7 +36,9 @@ TEST_F(InstructionFusionTest,
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(broadcast2, computation->root_instruction());
|
||||
EXPECT_TRUE(
|
||||
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
EXPECT_EQ(broadcast2, computation->root_instruction());
|
||||
}
|
||||
|
||||
@ -55,7 +57,9 @@ TEST_F(InstructionFusionTest,
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(broadcast2, computation->root_instruction());
|
||||
EXPECT_TRUE(
|
||||
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
|
||||
}
|
||||
|
||||
@ -73,7 +77,9 @@ TEST_F(InstructionFusionTest,
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(reshape2, computation->root_instruction());
|
||||
EXPECT_TRUE(
|
||||
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
|
||||
}
|
||||
|
||||
@ -91,7 +97,9 @@ TEST_F(InstructionFusionTest,
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(transpose2, computation->root_instruction());
|
||||
EXPECT_TRUE(
|
||||
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
EXPECT_EQ(HloOpcode::kFusion, computation->root_instruction()->opcode());
|
||||
}
|
||||
|
||||
@ -106,7 +114,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfParameterUnfused) {
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(reshape1, computation->root_instruction());
|
||||
EXPECT_FALSE(
|
||||
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
|
||||
@ -120,7 +130,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastSimpleReshapeOfParameterUnfused) {
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(reshape1, computation->root_instruction());
|
||||
EXPECT_FALSE(
|
||||
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
}
|
||||
|
||||
TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
|
||||
@ -134,7 +146,9 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfParameterUnfused) {
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
EXPECT_EQ(transpose1, computation->root_instruction());
|
||||
EXPECT_FALSE(
|
||||
InstructionFusion(/*may_duplicate=*/true).Run(module.get()).ValueOrDie());
|
||||
InstructionFusion(InstructionFusion::IsExpensive, /*may_duplicate=*/true)
|
||||
.Run(module.get())
|
||||
.ValueOrDie());
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -106,6 +106,7 @@ std::vector<std::pair<HloInstruction*, int64>> GetAllUsesOfInstructionAtIndex(
|
||||
// *) Is a loop fusion instruction where the only use of 'operand' at 'index'
|
||||
// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root
|
||||
// at operand 0.
|
||||
// *) Use of 'operand' is DynamicUpdateSlice at operand index 0.
|
||||
bool CanShareOperandBufferWithUser(
|
||||
HloInstruction* operand, const ShapeIndex& operand_index,
|
||||
HloInstruction* user, const ShapeIndex& user_index,
|
||||
@ -143,6 +144,11 @@ bool CanShareOperandBufferWithUser(
|
||||
break;
|
||||
}
|
||||
return false;
|
||||
} else if (user->opcode() == HloOpcode::kDynamicUpdateSlice) {
|
||||
// We eliminated other users in BufferLiveness::live_range_strictly_before,
|
||||
// so here we just need to check that the use is at operand index 0.
|
||||
std::vector<int64> operand_indices = user->OperandIndices(operand);
|
||||
return operand_indices.size() == 1 && operand_indices[0] == 0;
|
||||
}
|
||||
// Check if 'user' is element-wise.
|
||||
return user->IsElementwise();
|
||||
|
@ -256,8 +256,7 @@ StatusOr<std::vector<const Allocation*>> Service::ResolveAndValidateArguments(
|
||||
tensorflow::gtl::ArraySlice<const GlobalDataHandle*> arguments,
|
||||
const Backend* backend, int device_ordinal) {
|
||||
std::vector<const Allocation*> allocations;
|
||||
for (tensorflow::gtl::ArraySlice<const GlobalDataHandle*>::size_type i = 0;
|
||||
i < arguments.size(); ++i) {
|
||||
for (size_t i = 0; i < arguments.size(); ++i) {
|
||||
auto allocation_status = allocation_tracker_.Resolve(*arguments[i]);
|
||||
if (!allocation_status.ok()) {
|
||||
return Status(allocation_status.status().code(),
|
||||
@ -296,8 +295,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||
program_shape.parameters_size(), arguments.size());
|
||||
}
|
||||
|
||||
for (tensorflow::gtl::ArraySlice<const Allocation*>::size_type i = 0;
|
||||
i < arguments.size(); ++i) {
|
||||
for (size_t i = 0; i < arguments.size(); ++i) {
|
||||
// Verify that shape of arguments matches the shape of the arguments in the
|
||||
// ProgramShape.
|
||||
if (!ShapeUtil::Compatible(arguments[i]->shape(),
|
||||
@ -385,8 +383,7 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||
hlo_dumper, std::move(executors)));
|
||||
|
||||
if (!other_directory_path.empty()) {
|
||||
for (std::vector<VersionedComputationHandle>::size_type i = 0;
|
||||
i < versioned_handles.size(); ++i) {
|
||||
for (size_t i = 0; i < versioned_handles.size(); ++i) {
|
||||
executables[i]->set_session_module(std::move(session_modules[i]));
|
||||
}
|
||||
}
|
||||
|
@ -598,7 +598,9 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest {
|
||||
// Build computation and add it to module as entry computation.
|
||||
BuildModule(builder.Build());
|
||||
// Run instruction fusion HloPass.
|
||||
EXPECT_TRUE(InstructionFusion().Run(module_.get()).ValueOrDie());
|
||||
EXPECT_TRUE(InstructionFusion(InstructionFusion::IsExpensive)
|
||||
.Run(module_.get())
|
||||
.ValueOrDie());
|
||||
// Get computation root instruction (should be a kFusion).
|
||||
auto* fusion = module_->entry_computation()->root_instruction();
|
||||
EXPECT_EQ(HloOpcode::kFusion, fusion->opcode());
|
||||
|
@ -56,6 +56,8 @@ class ClientLibraryTestBase : public ::testing::Test {
|
||||
execution_options_.set_disable_fast_math(disabled);
|
||||
}
|
||||
|
||||
void SetSeed(uint64 seed) { execution_options_.set_seed(seed); }
|
||||
|
||||
// TODO(b/25566808): Add helper that populates a literal from a testdata file.
|
||||
|
||||
// Convenience methods for building and running a computation from a builder.
|
||||
|
@ -187,8 +187,13 @@ ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
|
||||
}
|
||||
|
||||
ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
|
||||
return ExecutableRunOptions().set_allocator(
|
||||
GetOrCreateAllocator(local_client_->platform()));
|
||||
ExecutableRunOptions run_options;
|
||||
run_options.set_inter_op_thread_pool(
|
||||
local_client_->backend().inter_op_thread_pool());
|
||||
run_options.set_intra_op_thread_pool(
|
||||
local_client_->backend().eigen_intra_op_thread_pool_device());
|
||||
run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
|
||||
return run_options;
|
||||
}
|
||||
|
||||
std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocallyOrDie(
|
||||
|
@ -53,6 +53,7 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice<int64> dims) {
|
||||
builder.ConstantR0<T>(a), builder.ConstantR0<T>(b),
|
||||
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims));
|
||||
|
||||
SetSeed(42);
|
||||
auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
|
||||
EXPECT_TRUE(ContainersEqual(dims, actual->shape().dimensions()));
|
||||
LiteralUtil::EachCell<T>(*actual,
|
||||
@ -118,6 +119,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) {
|
||||
builder.ConstantR0<int32>(range_size),
|
||||
ShapeUtil::MakeShape(S32, {sample_size}));
|
||||
|
||||
SetSeed(42);
|
||||
auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
|
||||
std::vector<int32> counts(range_size, 0);
|
||||
LiteralUtil::EachCell<int32>(
|
||||
@ -264,6 +266,7 @@ XLA_TEST_F(PrngTest, TenValuesN01) {
|
||||
builder.RngNormal(builder.ConstantR0<float>(0), builder.ConstantR0<float>(1),
|
||||
ShapeUtil::MakeShape(F32, {10}));
|
||||
|
||||
SetSeed(42);
|
||||
ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
|
||||
// TODO(b/25995601): Test that resultant values are reasonable
|
||||
}
|
||||
|
@ -91,6 +91,7 @@ cc_library(
|
||||
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
|
||||
"//tensorflow/contrib/layers:bucketization_op_op_lib",
|
||||
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
|
||||
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -37,7 +37,7 @@ tf_kernel_library(
|
||||
srcs = [
|
||||
"bigquery_reader_ops.cc",
|
||||
],
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":bigquery_table_accessor",
|
||||
":bigquery_table_partition_proto_cc",
|
||||
|
@ -419,7 +419,7 @@ class DirichletMultinomialTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
dist = ds.DirichletMultinomial(
|
||||
total_count=5.,
|
||||
concentration=2. * self._rng.rand(4, 3, 2).astype(np.float32))
|
||||
concentration=1. + 2. * self._rng.rand(4, 3, 2).astype(np.float32))
|
||||
n = int(3e3)
|
||||
x = dist.sample(n, seed=0)
|
||||
sample_mean = math_ops.reduce_mean(x, 0)
|
||||
@ -448,7 +448,7 @@ class DirichletMultinomialTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
dist = ds.DirichletMultinomial(
|
||||
total_count=5.,
|
||||
concentration=2. * self._rng.rand(4).astype(np.float32))
|
||||
concentration=1. + 2. * self._rng.rand(4).astype(np.float32))
|
||||
n = int(5e3)
|
||||
x = dist.sample(n, seed=0)
|
||||
sample_mean = math_ops.reduce_mean(x, 0)
|
||||
|
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
from tensorflow.contrib import distributions
|
||||
from tensorflow.contrib.distributions.python.ops import bijectors
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
@ -50,6 +51,22 @@ class MultivariateNormalDiagTest(test.TestCase):
|
||||
dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
|
||||
self.assertAllEqual([3, 1], dist.sample(3).get_shape())
|
||||
|
||||
def testDistWithBatchShapeOneThenTransformedThroughSoftplus(self):
|
||||
# This complex combination of events resulted in a loss of static shape
|
||||
# information when tensor_util.constant_value(self._needs_rotation) was
|
||||
# being used incorrectly (resulting in always rotating).
|
||||
# Batch shape = [1], event shape = [3]
|
||||
mu = array_ops.zeros((1, 3))
|
||||
diag = array_ops.ones((1, 3))
|
||||
with self.test_session():
|
||||
base_dist = ds.MultivariateNormalDiag(mu, diag, validate_args=True)
|
||||
dist = ds.TransformedDistribution(
|
||||
base_dist,
|
||||
validate_args=True,
|
||||
bijector=bijectors.Softplus(event_ndims=1))
|
||||
samps = dist.sample(5) # Shape [5, 1, 3].
|
||||
self.assertAllEqual([5, 1], dist.log_prob(samps).get_shape())
|
||||
|
||||
def testMean(self):
|
||||
mu = [-1., 1]
|
||||
diag = [1., -5]
|
||||
|
@ -546,7 +546,8 @@ class TransformedDistribution(distributions.Distribution):
|
||||
|
||||
def _maybe_rotate_dims(self, x, rotate_right=False):
|
||||
"""Helper which rolls left event_dims left or right event_dims right."""
|
||||
if tensor_util.constant_value(self._needs_rotation) is False:
|
||||
needs_rotation_const = tensor_util.constant_value(self._needs_rotation)
|
||||
if needs_rotation_const is not None and not needs_rotation_const:
|
||||
return x
|
||||
ndims = array_ops.rank(x)
|
||||
n = (ndims - self._rotate_ndims) if rotate_right else self._rotate_ndims
|
||||
|
@ -35,6 +35,7 @@ See the @{$python/contrib.layers} guide.
|
||||
@@relu6
|
||||
@@repeat
|
||||
@@safe_embedding_lookup_sparse
|
||||
@@scale_gradient
|
||||
@@separable_conv2d
|
||||
@@separable_convolution2d
|
||||
@@softmax
|
||||
@ -68,6 +69,7 @@ See the @{$python/contrib.layers} guide.
|
||||
@@embedding_column
|
||||
@@scattered_embedding_column
|
||||
@@input_from_feature_columns
|
||||
@@transform_features
|
||||
@@joint_weighted_sum_from_feature_columns
|
||||
@@make_place_holder_tensors_for_base_features
|
||||
@@multi_class_target
|
||||
|
@ -1248,21 +1248,29 @@ def scattered_embedding_column(column_name,
|
||||
initializer=None):
|
||||
"""Creates an embedding column of a sparse feature using parameter hashing.
|
||||
|
||||
The i-th embedding component of a value v is found by retrieving an
|
||||
embedding weight whose index is a fingerprint of the pair (v,i).
|
||||
This is a useful shorthand when you have a sparse feature you want to use an
|
||||
embedding for, but also want to hash the embedding's values in each dimension
|
||||
to a variable based on a different hash.
|
||||
|
||||
Specifically, the i-th embedding component of a value v is found by retrieving
|
||||
an embedding weight whose index is a fingerprint of the pair (v,i).
|
||||
|
||||
An embedding column with sparse_column_with_hash_bucket such as
|
||||
embedding_column(
|
||||
|
||||
embedding_column(
|
||||
sparse_column_with_hash_bucket(column_name, bucket_size),
|
||||
dimension)
|
||||
|
||||
could be replaced by
|
||||
scattered_embedding_column(
|
||||
column_name, size=bucket_size * dimension, dimension=dimension,
|
||||
|
||||
scattered_embedding_column(
|
||||
column_name,
|
||||
size=bucket_size * dimension,
|
||||
dimension=dimension,
|
||||
hash_key=tf.contrib.layers.SPARSE_FEATURE_CROSS_DEFAULT_HASH_KEY)
|
||||
|
||||
for the same number of embedding parameters and hopefully reduced impact of
|
||||
collisions with a cost of slowing down training.
|
||||
for the same number of embedding parameters. This should hopefully reduce the
|
||||
impact of collisions, but adds the cost of slowing down training.
|
||||
|
||||
Args:
|
||||
column_name: A string defining sparse column name.
|
||||
|
@ -144,6 +144,7 @@ def _input_from_feature_columns(columns_to_tensors,
|
||||
output_rank,
|
||||
default_name):
|
||||
"""Implementation of `input_from(_sequence)_feature_columns`."""
|
||||
columns_to_tensors = columns_to_tensors.copy()
|
||||
check_feature_columns(feature_columns)
|
||||
with variable_scope.variable_scope(scope,
|
||||
default_name=default_name,
|
||||
@ -430,6 +431,7 @@ def joint_weighted_sum_from_feature_columns(columns_to_tensors,
|
||||
ValueError: if FeatureColumn cannot be used for linear predictions.
|
||||
|
||||
"""
|
||||
columns_to_tensors = columns_to_tensors.copy()
|
||||
check_feature_columns(feature_columns)
|
||||
with variable_scope.variable_scope(
|
||||
scope,
|
||||
@ -518,6 +520,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
|
||||
Raises:
|
||||
ValueError: if FeatureColumn cannot be used for linear predictions.
|
||||
"""
|
||||
columns_to_tensors = columns_to_tensors.copy()
|
||||
check_feature_columns(feature_columns)
|
||||
with variable_scope.variable_scope(
|
||||
scope,
|
||||
@ -684,8 +687,8 @@ def transform_features(features, feature_columns):
|
||||
Returns:
|
||||
A `dict` mapping FeatureColumn to `Tensor` and `SparseTensor` values.
|
||||
"""
|
||||
check_feature_columns(feature_columns)
|
||||
columns_to_tensor = features.copy()
|
||||
check_feature_columns(feature_columns)
|
||||
transformer = _Transformer(columns_to_tensor)
|
||||
for column in sorted(set(feature_columns), key=lambda x: x.key):
|
||||
transformer.transform(column)
|
||||
|
@ -187,27 +187,28 @@ class TransformerTest(test.TestCase):
|
||||
self.assertAllEqual(output.dense_shape.eval(), [2, 2])
|
||||
|
||||
def testEmbeddingColumn(self):
|
||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
||||
wire_tensor = sparse_tensor.SparseTensor(
|
||||
values=["omar", "stringer", "marlo"],
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
dense_shape=[2, 2])
|
||||
features = {"wire": wire_tensor}
|
||||
output = feature_column_ops._Transformer(features).transform(
|
||||
feature_column.embedding_column(hashed_sparse, 10))
|
||||
expected = feature_column_ops._Transformer(features).transform(
|
||||
hashed_sparse)
|
||||
with self.test_session():
|
||||
self.assertAllEqual(output.values.eval(), expected.values.eval())
|
||||
self.assertAllEqual(output.indices.eval(), expected.indices.eval())
|
||||
self.assertAllEqual(output.dense_shape.eval(),
|
||||
expected.dense_shape.eval())
|
||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
||||
wire_embedding = feature_column.embedding_column(hashed_sparse, 10)
|
||||
|
||||
# Test transform features.
|
||||
output = feature_column_ops.transform_features(
|
||||
features=features, feature_columns=[hashed_sparse])
|
||||
self.assertEqual(len(output), 1)
|
||||
features=features, feature_columns=[hashed_sparse, wire_embedding])
|
||||
# Check that features dict haven't changed
|
||||
self.assertEqual({"wire": wire_tensor}, features)
|
||||
self.assertEqual(len(output), 2)
|
||||
self.assertIn(hashed_sparse, output)
|
||||
self.assertIn(wire_embedding, output)
|
||||
with self.test_session():
|
||||
self.assertAllEqual(output[wire_embedding].indices.eval(),
|
||||
wire_tensor.indices.eval())
|
||||
self.assertAllEqual(output[wire_embedding].dense_shape.eval(), [2, 2])
|
||||
self.assertAllEqual(output[wire_embedding].values.eval(),
|
||||
output[hashed_sparse].values.eval())
|
||||
|
||||
def testSparseColumnWithKeys(self):
|
||||
keys_sparse = feature_column.sparse_column_with_keys(
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.contrib.framework.python.ops import variables
|
||||
from tensorflow.contrib.layers.python.layers import initializers
|
||||
from tensorflow.contrib.layers.python.layers import utils
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.layers import convolutional as convolutional_layers
|
||||
@ -68,6 +69,7 @@ __all__ = ['avg_pool2d',
|
||||
'relu',
|
||||
'relu6',
|
||||
'repeat',
|
||||
'scale_gradient',
|
||||
'separable_conv2d',
|
||||
'separable_convolution2d',
|
||||
'softmax',
|
||||
@ -1745,6 +1747,48 @@ def repeat(inputs, repetitions, layer, *args, **kwargs):
|
||||
return outputs
|
||||
|
||||
|
||||
def _scale_gradient_shape(op):
|
||||
"""Shape helper function for scale_gradient function below."""
|
||||
return [op.inputs[0].shape]
|
||||
|
||||
|
||||
def _scale_gradient_grad(op, grad):
|
||||
"""Python gradient helper function for scale_gradient function below."""
|
||||
return [grad * op.inputs[1], None]
|
||||
|
||||
|
||||
@function.Defun(python_grad_func=_scale_gradient_grad,
|
||||
shape_func=_scale_gradient_shape)
|
||||
def scale_gradient(inputs, gradient_multiplier):
|
||||
"""Identity operation, but with the gradient multiplied by a tensor.
|
||||
|
||||
The TensorFlow gradient system will compute the gradient with respect to
|
||||
`inputs` as the product of the gradient with respect to the `output`
|
||||
multiplied by a specified `gradient_multiplier` tensor. If
|
||||
`gradient_multiplier` is equal to 1, then this results in the true gradient.
|
||||
Otherwise, it results in a scaled gradient.
|
||||
|
||||
This can be useful for adjusting the relative learning rate of different
parameter tensors when performing gradient descent. Because this rescaling can
be inserted at arbitrary locations within a graph, it is often more convenient
to apply than simply rescaling the final computed gradients.
|
||||
|
||||
Args:
|
||||
inputs: Tensor to be output.
|
||||
gradient_multiplier: Tensor by which to multiply the gradient with respect
|
||||
to `output` to compute the gradient with respect to `inputs`. Its shape
|
||||
must be broadcastable to the shape of `inputs`.
|
||||
|
||||
Returns:
|
||||
output Tensor, equal to `inputs`.
|
||||
"""
|
||||
# gradient_multiplier is implicitly saved by decorator, and only used for
|
||||
# gradient computation.
|
||||
del gradient_multiplier
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
@add_arg_scope
|
||||
def separable_convolution2d(
|
||||
inputs,
|
||||
|
@ -2980,6 +2980,22 @@ class SeparableConv2dTest(test.TestCase):
|
||||
sess.run(net, feed_dict={images_placeholder: images})
|
||||
|
||||
|
||||
class ScaleGradientTests(test.TestCase):
|
||||
"""Simple tests of the scale_gradient function."""
|
||||
|
||||
def testBasic(self):
|
||||
with self.test_session():
|
||||
x = np.array([42], np.float32)
|
||||
gradient_scale = np.array([2], np.float32)
|
||||
|
||||
x = ops.convert_to_tensor(x)
|
||||
y = layers_lib.scale_gradient(x, gradient_scale)
|
||||
|
||||
np.testing.assert_array_equal(x.eval(), y.eval())
|
||||
g_x, = gradients_impl.gradients(y, [x], [np.array([3], np.float32)])
|
||||
np.testing.assert_array_equal([3 * 2], g_x.eval())
|
||||
|
||||
|
||||
class SoftmaxTests(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -165,8 +165,7 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
|
||||
* embedding_lr_multipliers: Optional. A dictionary from
|
||||
`EmbeddingColumn` to a `float` multiplier. Multiplier will be used to
|
||||
multiply with learning rate for the embedding variables.
|
||||
* input_layer_min_slice_size: Optional. The min slice size of input layer
|
||||
partitions. If not provided, will use the default of 64M.
|
||||
* input_layer_partitioner: Optional. Partitioner for input layer.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
|
||||
Returns:
|
||||
@ -174,7 +173,7 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
|
||||
|
||||
Raises:
|
||||
ValueError: If both `linear_feature_columns` and `dnn_features_columns`
|
||||
are empty at the same time.
|
||||
are empty at the same time, or `input_layer_partitioner` is missing.
|
||||
"""
|
||||
head = params["head"]
|
||||
linear_feature_columns = params.get("linear_feature_columns")
|
||||
@ -186,9 +185,11 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
|
||||
dnn_activation_fn = params.get("dnn_activation_fn") or nn.relu
|
||||
dnn_dropout = params.get("dnn_dropout")
|
||||
gradient_clip_norm = params.get("gradient_clip_norm")
|
||||
input_layer_min_slice_size = (
|
||||
params.get("input_layer_min_slice_size") or 64 << 20)
|
||||
num_ps_replicas = config.num_ps_replicas if config else 0
|
||||
input_layer_partitioner = params.get("input_layer_partitioner") or (
|
||||
partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=num_ps_replicas,
|
||||
min_slice_size=64 << 20))
|
||||
embedding_lr_multipliers = params.get("embedding_lr_multipliers", {})
|
||||
fix_global_step_increment_bug = params.get(
|
||||
"fix_global_step_increment_bug", True)
|
||||
@ -221,10 +222,6 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None):
|
||||
dnn_parent_scope,
|
||||
values=tuple(six.itervalues(features)),
|
||||
partitioner=dnn_partitioner):
|
||||
input_layer_partitioner = (
|
||||
partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=num_ps_replicas,
|
||||
min_slice_size=input_layer_min_slice_size))
|
||||
with variable_scope.variable_scope(
|
||||
"input_from_feature_columns",
|
||||
values=tuple(six.itervalues(features)),
|
||||
@ -387,7 +384,8 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
|
||||
config=None,
|
||||
feature_engineering_fn=None,
|
||||
embedding_lr_multipliers=None,
|
||||
fix_global_step_increment_bug=False):
|
||||
fix_global_step_increment_bug=False,
|
||||
input_layer_partitioner=None):
|
||||
"""Initializes a DNNLinearCombinedEstimator instance.
|
||||
|
||||
Note: New users must set `fix_global_step_increment_bug=True` when creating
|
||||
@ -432,6 +430,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
|
||||
steps to optimize both linear and dnn parts. If `True`, this bug is
|
||||
fixed. New users must set this to `True`, but the default value is
|
||||
`False` for backwards compatibility.
|
||||
input_layer_partitioner: Optional. Partitioner for input layer.
|
||||
|
||||
Raises:
|
||||
ValueError: If both linear_feature_columns and dnn_features_columns are
|
||||
@ -459,6 +458,7 @@ class DNNLinearCombinedEstimator(estimator.Estimator):
|
||||
"gradient_clip_norm": gradient_clip_norm,
|
||||
"embedding_lr_multipliers": embedding_lr_multipliers,
|
||||
"fix_global_step_increment_bug": fix_global_step_increment_bug,
|
||||
"input_layer_partitioner": input_layer_partitioner
|
||||
},
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
||||
@ -602,19 +602,25 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
||||
ValueError: If both `linear_feature_columns` and `dnn_features_columns`
|
||||
are empty at the same time.
|
||||
"""
|
||||
if n_classes < 2:
|
||||
raise ValueError("n_classes should be greater than 1. Given: {}".format(
|
||||
n_classes))
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
linear_feature_columns = tuple(linear_feature_columns or [])
|
||||
dnn_feature_columns = tuple(dnn_feature_columns or [])
|
||||
self._feature_columns = linear_feature_columns + dnn_feature_columns
|
||||
if not self._feature_columns:
|
||||
raise ValueError("Either linear_feature_columns or dnn_feature_columns "
|
||||
"must be defined.")
|
||||
head = head_lib.multi_class_head(
|
||||
n_classes=n_classes,
|
||||
weight_column_name=weight_column_name,
|
||||
enable_centered_bias=enable_centered_bias)
|
||||
|
||||
# TODO(b/35922130): Replace with `input_layer_partitioner` arg.
|
||||
input_layer_partitioner = None
|
||||
if input_layer_min_slice_size is not None:
|
||||
input_layer_partitioner = (
|
||||
partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=config.num_ps_replicas if config else 0,
|
||||
min_slice_size=input_layer_min_slice_size))
|
||||
|
||||
super(DNNLinearCombinedClassifier, self).__init__(
|
||||
model_fn=_dnn_linear_combined_model_fn,
|
||||
model_dir=model_dir,
|
||||
@ -631,7 +637,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator):
|
||||
"dnn_dropout": dnn_dropout,
|
||||
"gradient_clip_norm": gradient_clip_norm,
|
||||
"embedding_lr_multipliers": embedding_lr_multipliers,
|
||||
"input_layer_min_slice_size": input_layer_min_slice_size,
|
||||
"input_layer_partitioner": input_layer_partitioner,
|
||||
"fix_global_step_increment_bug": fix_global_step_increment_bug,
|
||||
},
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
@ -916,6 +922,15 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
|
||||
if not self._feature_columns:
|
||||
raise ValueError("Either linear_feature_columns or dnn_feature_columns "
|
||||
"must be defined.")
|
||||
|
||||
# TODO(b/35922130): Replace with `input_layer_partitioner` arg.
|
||||
input_layer_partitioner = None
|
||||
if input_layer_min_slice_size is not None:
|
||||
input_layer_partitioner = (
|
||||
partitioned_variables.min_max_variable_partitioner(
|
||||
max_partitions=config.num_ps_replicas if config else 0,
|
||||
min_slice_size=input_layer_min_slice_size))
|
||||
|
||||
head = head_lib.regression_head(
|
||||
weight_column_name=weight_column_name,
|
||||
label_dimension=label_dimension,
|
||||
@ -936,7 +951,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator):
|
||||
"dnn_dropout": dnn_dropout,
|
||||
"gradient_clip_norm": gradient_clip_norm,
|
||||
"embedding_lr_multipliers": embedding_lr_multipliers,
|
||||
"input_layer_min_slice_size": input_layer_min_slice_size,
|
||||
"input_layer_partitioner": input_layer_partitioner,
|
||||
"fix_global_step_increment_bug": fix_global_step_increment_bug,
|
||||
},
|
||||
feature_engineering_fn=feature_engineering_fn)
|
||||
|
@ -341,7 +341,7 @@ class DNNLinearCombinedClassifierTest(test.TestCase):
|
||||
input_layer_min_slice_size=1)
|
||||
|
||||
# Ensure the param is passed in.
|
||||
self.assertEqual(1, classifier.params['input_layer_min_slice_size'])
|
||||
self.assertTrue(callable(classifier.params['input_layer_partitioner']))
|
||||
|
||||
# Ensure the partition count is 10.
|
||||
classifier.fit(input_fn=_input_fn_float_label, steps=50)
|
||||
|
@ -188,6 +188,10 @@ def build_sequence_input(features,
|
||||
A `Tensor` of dtype `float32` and shape `[batch_size, padded_length, ?]`.
|
||||
This will be used as input to an RNN.
|
||||
"""
|
||||
features = features.copy()
|
||||
features.update(layers.transform_features(
|
||||
features,
|
||||
list(sequence_feature_columns) + list(context_feature_columns or [])))
|
||||
sequence_input = layers.sequence_input_from_feature_columns(
|
||||
columns_to_tensors=features,
|
||||
feature_columns=sequence_feature_columns,
|
||||
|
@ -378,6 +378,8 @@ class BaseEstimator(
|
||||
self._model_dir = tempfile.mkdtemp()
|
||||
logging.warning('Using temporary folder as model directory: %s',
|
||||
self._model_dir)
|
||||
if self._config.model_dir is None:
|
||||
self._config = self._config.replace(model_dir=self._model_dir)
|
||||
|
||||
# Set device function depending if there are replicas or not.
|
||||
self._device_fn = _get_replica_device_setter(self._config)
|
||||
|
@ -481,6 +481,7 @@ class EstimatorTest(test.TestCase):
|
||||
est = estimator.Estimator(model_fn=linear_model_fn,
|
||||
config=config)
|
||||
self.assertEqual('test_dir', est.config.model_dir)
|
||||
self.assertEqual('test_dir', est.model_dir)
|
||||
|
||||
def testModelDirAndRunConfigModelDir(self):
|
||||
config = run_config.RunConfig(model_dir='test_dir')
|
||||
@ -489,11 +490,30 @@ class EstimatorTest(test.TestCase):
|
||||
model_dir='test_dir')
|
||||
self.assertEqual('test_dir', est.config.model_dir)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError,
|
||||
'model_dir are set both in constructor and RunConfig, '
|
||||
'but with different'):
|
||||
estimator.Estimator(model_fn=linear_model_fn,
|
||||
config=config,
|
||||
model_dir='different_dir')
|
||||
|
||||
def testModelDirIsCopiedToRunConfig(self):
|
||||
config = run_config.RunConfig()
|
||||
self.assertIsNone(config.model_dir)
|
||||
|
||||
est = estimator.Estimator(model_fn=linear_model_fn,
|
||||
model_dir='test_dir',
|
||||
config=config)
|
||||
self.assertEqual('test_dir', est.config.model_dir)
|
||||
self.assertEqual('test_dir', est.model_dir)
|
||||
|
||||
def testModelDirAsTempDir(self):
|
||||
with test.mock.patch.object(tempfile, 'mkdtemp', return_value='temp_dir'):
|
||||
est = estimator.Estimator(model_fn=linear_model_fn)
|
||||
self.assertEqual('temp_dir', est.config.model_dir)
|
||||
self.assertEqual('temp_dir', est.model_dir)
|
||||
|
||||
def testCheckInputs(self):
|
||||
est = estimator.SKCompat(estimator.Estimator(model_fn=linear_model_fn))
|
||||
# Lambdas so we have to different objects to compare
|
||||
|
@ -231,6 +231,8 @@ def sdca_model_fn(features, labels, mode, params):
|
||||
|
||||
with variable_scope.variable_op_scope(
|
||||
features.values(), parent_scope) as scope:
|
||||
features = features.copy()
|
||||
features.update(layers.transform_features(features, feature_columns))
|
||||
logits, columns_to_variables, bias = (
|
||||
layers.weighted_sum_from_feature_columns(
|
||||
columns_to_tensors=features,
|
||||
|
@ -18,12 +18,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.contrib.framework.python.framework import experimental
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python.training import server_lib
|
||||
|
||||
@ -291,7 +293,7 @@ class RunConfig(ClusterConfig):
|
||||
|
||||
new_copy = copy.deepcopy(self)
|
||||
|
||||
# TODO(xiejw): Allow more fields, such as the user allowed changed ones.
|
||||
# TODO(b/33295821): Allow more fields to be replaced.
|
||||
for key, new_value in six.iteritems(kwargs):
|
||||
if key == 'model_dir':
|
||||
new_copy._model_dir = new_value # pylint: disable=protected-access
|
||||
@ -301,6 +303,24 @@ class RunConfig(ClusterConfig):
|
||||
|
||||
return new_copy
|
||||
|
||||
@experimental
|
||||
def uid(self):
|
||||
"""Generates a 'Unique Identifier' based on all internal fields.
|
||||
|
||||
Caller should use the uid string to check `RunConfig` instance integrity
within one session use, but should not rely on the implementation details,
which are subject to change.
|
||||
|
||||
Returns:
|
||||
A uid string.
|
||||
"""
|
||||
# TODO(b/33295821): Allows user to specify a whitelist.
|
||||
state = {k: v for k, v in self.__dict__.items() if not k.startswith('__')}
|
||||
ordered_state = collections.OrderedDict(
|
||||
sorted(state.items(), key=lambda t: t[0]))
|
||||
return ', '.join(
|
||||
'%s=%r' % (k, v) for (k, v) in six.iteritems(ordered_state))
|
||||
|
||||
@property
|
||||
def model_dir(self):
|
||||
return self._model_dir
|
||||
|
@ -238,6 +238,18 @@ class RunConfigTest(test.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
config.replace(some_undefined_property=RANDOM_SEED)
|
||||
|
||||
def test_uid(self):
|
||||
config = run_config.RunConfig(
|
||||
tf_random_seed=RANDOM_SEED, model_dir=TEST_DIR)
|
||||
|
||||
expected_uid = config.uid()
|
||||
# Check for 10 times, which should prove something.
|
||||
for _ in range(10):
|
||||
self.assertEqual(expected_uid, config.uid())
|
||||
|
||||
new_config = config.replace(model_dir=ANOTHER_TEST_DIR)
|
||||
self.assertNotEqual(expected_uid, new_config.uid())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -42,9 +42,7 @@ from tensorflow.python.ops import logging_ops
|
||||
from tensorflow.python.ops import resources
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import coordinator
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import queue_runner
|
||||
from tensorflow.python.training import saver as tf_saver
|
||||
from tensorflow.python.training import session_manager as session_manager_lib
|
||||
@ -121,205 +119,6 @@ def _run_with_monitors(session, step, tensors, feed_dict, monitors):
|
||||
return outputs, should_stop
|
||||
|
||||
|
||||
def _monitored_train(graph,
|
||||
output_dir,
|
||||
train_op,
|
||||
loss_op,
|
||||
global_step_tensor=None,
|
||||
init_op=None,
|
||||
init_feed_dict=None,
|
||||
init_fn=None,
|
||||
log_every_steps=10,
|
||||
supervisor_is_chief=True,
|
||||
supervisor_master='',
|
||||
supervisor_save_model_secs=600,
|
||||
supervisor_save_model_steps=None,
|
||||
keep_checkpoint_max=5,
|
||||
keep_checkpoint_every_n_hours=10000.0,
|
||||
supervisor_save_summaries_secs=None,
|
||||
supervisor_save_summaries_steps=100,
|
||||
feed_fn=None,
|
||||
steps=None,
|
||||
fail_on_nan_loss=True,
|
||||
hooks=None,
|
||||
max_steps=None):
|
||||
"""Train a model via monitored_session.
|
||||
|
||||
Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
|
||||
run a training loop. The given `train_op` performs one step of training on the
|
||||
model. The `loss_op` represents the objective function of the training. It is
|
||||
expected to increment the `global_step_tensor`, a scalar integer tensor
|
||||
counting training steps. This function uses `Supervisor` to initialize the
|
||||
graph (from a checkpoint if one is available in `output_dir`), write summaries
|
||||
defined in the graph, and write regular checkpoints as defined by
|
||||
`supervisor_save_model_secs`.
|
||||
|
||||
Training continues until `global_step_tensor` evaluates to `max_steps`, or, if
|
||||
`fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the
|
||||
program is terminated with exit code 1.
|
||||
|
||||
Args:
|
||||
graph: A graph to train. It is expected that this graph is not in use
|
||||
elsewhere.
|
||||
output_dir: A directory to write outputs to.
|
||||
train_op: An op that performs one training step when run.
|
||||
loss_op: A scalar loss tensor.
|
||||
global_step_tensor: A tensor representing the global step. If none is given,
|
||||
one is extracted from the graph using the same logic as in `Supervisor`.
|
||||
init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
|
||||
default.
|
||||
init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
|
||||
This feed dictionary will be used when `init_op` is evaluated.
|
||||
init_fn: Optional callable passed to Supervisor to initialize the model.
|
||||
log_every_steps: Output logs regularly. The logs contain timing data and the
|
||||
current loss. A `0` or negative value disables logging.
|
||||
supervisor_is_chief: Whether the current process is the chief supervisor in
|
||||
charge of restoring the model and running standard services.
|
||||
supervisor_master: The master string to use when preparing the session.
|
||||
supervisor_save_model_secs: Save checkpoints every this many seconds. Can
|
||||
not be specified with `supervisor_save_model_steps`.
|
||||
supervisor_save_model_steps: Save checkpoints every this many steps. Can not
|
||||
be specified with `supervisor_save_model_secs`.
|
||||
keep_checkpoint_max: The maximum number of recent checkpoint files to
|
||||
keep. As new files are created, older files are deleted. If None or 0,
|
||||
all checkpoint files are kept. This is simply passed as the max_to_keep
|
||||
arg to `tf.train.Saver` constructor.
|
||||
keep_checkpoint_every_n_hours: In addition to keeping the most recent
|
||||
`keep_checkpoint_max` checkpoint files, you might want to keep one checkpoint file
|
||||
for every N hours of training. This can be useful if you want to later
|
||||
analyze how a model progressed during a long training session. For
|
||||
example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
|
||||
one checkpoint file for every 2 hours of training. The default value of
|
||||
10,000 hours effectively disables the feature.
|
||||
supervisor_save_summaries_secs: Save summaries every
|
||||
`supervisor_save_summaries_secs` seconds when training.
|
||||
supervisor_save_summaries_steps: Save summaries every
|
||||
`supervisor_save_summaries_steps` steps when training. Exactly one of
|
||||
`supervisor_save_model_steps` and `supervisor_save_model_secs` should be
|
||||
specified, and the other should be None.
|
||||
feed_fn: A function that is called every iteration to produce a `feed_dict`
|
||||
passed to `session.run` calls. Optional.
|
||||
steps: Trains for this many steps (e.g. current global step + `steps`).
|
||||
fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
|
||||
evaluates to `NaN`. If false, continue training as if nothing happened.
|
||||
hooks: List of `SessionRunHook` subclass instances. Used for callbacks
|
||||
inside the training loop.
|
||||
max_steps: Number of total steps for which to train model. If `None`,
|
||||
train forever. Two calls fit(steps=100) means 200 training iterations.
|
||||
On the other hand two calls of fit(max_steps=100) means, second call
|
||||
will not do any iteration since first call did all 100 steps.
|
||||
|
||||
Returns:
|
||||
The final loss value.
|
||||
|
||||
Raises:
|
||||
ValueError: If `output_dir`, `train_op`, `loss_op`, or `global_step_tensor`
|
||||
is not provided. See `tf.contrib.framework.get_global_step` for how we
|
||||
look up the latter if not provided explicitly.
|
||||
NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
|
||||
evaluates to `NaN`.
|
||||
ValueError: If both `steps` and `max_steps` are not `None`.
|
||||
"""
|
||||
if (steps is not None) and (max_steps is not None):
|
||||
raise ValueError('Can not provide both steps and max_steps.')
|
||||
if not output_dir:
|
||||
raise ValueError('Output directory should be non-empty %s.' % output_dir)
|
||||
if train_op is None:
|
||||
raise ValueError('Missing train_op.')
|
||||
if loss_op is None:
|
||||
raise ValueError('Missing loss_op.')
|
||||
if hooks is None:
|
||||
hooks = []
|
||||
if not isinstance(hooks, list):
|
||||
raise ValueError('Hooks should be a list.')
|
||||
with graph.as_default():
|
||||
global_step_tensor = contrib_variables.assert_or_get_global_step(
|
||||
graph, global_step_tensor)
|
||||
if global_step_tensor is None:
|
||||
raise ValueError('No "global_step" was provided or found in the graph.')
|
||||
|
||||
if max_steps is not None:
|
||||
try:
|
||||
start_step = load_variable(output_dir, global_step_tensor.name)
|
||||
if max_steps <= start_step:
|
||||
logging.info('Skipping training since max_steps has already saved.')
|
||||
return None
|
||||
except: # pylint: disable=bare-except
|
||||
pass
|
||||
|
||||
# Adapted SessionRunHooks such as ExportMonitor depend on the
|
||||
# CheckpointSaverHook to be executed before they should be executed.
|
||||
# The `hooks` param comprises of deprecated monitor hooks
|
||||
# (such as ExportMonitor). Appending them after the basic_session_run_hooks.
|
||||
all_hooks = []
|
||||
with graph.as_default():
|
||||
all_hooks.append(basic_session_run_hooks.NanTensorHook(
|
||||
loss_op, fail_on_nan_loss=fail_on_nan_loss))
|
||||
if log_every_steps > 0:
|
||||
all_hooks.append(basic_session_run_hooks.LoggingTensorHook({
|
||||
'loss': loss_op.name,
|
||||
'step': global_step_tensor.name
|
||||
}, every_n_iter=log_every_steps))
|
||||
|
||||
def make_saver():
|
||||
return tf_saver.Saver(
|
||||
sharded=True,
|
||||
max_to_keep=keep_checkpoint_max,
|
||||
keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
|
||||
defer_build=True)
|
||||
|
||||
scaffold = monitored_session.Scaffold(
|
||||
init_op=init_op,
|
||||
init_feed_dict=init_feed_dict,
|
||||
init_fn=init_fn,
|
||||
saver=monitored_session.Scaffold.get_or_default('saver',
|
||||
ops.GraphKeys.SAVERS,
|
||||
make_saver))
|
||||
|
||||
if not supervisor_is_chief:
|
||||
session_creator = monitored_session.WorkerSessionCreator(
|
||||
scaffold=scaffold,
|
||||
master=supervisor_master)
|
||||
else:
|
||||
session_creator = monitored_session.ChiefSessionCreator(
|
||||
scaffold=scaffold,
|
||||
checkpoint_dir=output_dir,
|
||||
master=supervisor_master)
|
||||
summary_writer = summary_io.SummaryWriterCache.get(output_dir)
|
||||
all_hooks.append(
|
||||
basic_session_run_hooks.StepCounterHook(
|
||||
summary_writer=summary_writer))
|
||||
all_hooks.append(
|
||||
basic_session_run_hooks.SummarySaverHook(
|
||||
save_secs=supervisor_save_summaries_secs,
|
||||
save_steps=supervisor_save_summaries_steps,
|
||||
summary_writer=summary_writer,
|
||||
scaffold=scaffold))
|
||||
if (supervisor_save_model_secs is not None
|
||||
or supervisor_save_model_steps is not None):
|
||||
all_hooks.append(
|
||||
basic_session_run_hooks.CheckpointSaverHook(
|
||||
output_dir,
|
||||
save_secs=supervisor_save_model_secs,
|
||||
save_steps=supervisor_save_model_steps,
|
||||
scaffold=scaffold))
|
||||
|
||||
if steps is not None or max_steps is not None:
|
||||
all_hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
|
||||
all_hooks.extend(hooks)
|
||||
|
||||
with monitored_session.MonitoredSession(
|
||||
session_creator=session_creator,
|
||||
hooks=all_hooks) as super_sess:
|
||||
loss = None
|
||||
while not super_sess.should_stop():
|
||||
_, loss = super_sess.run([train_op, loss_op],
|
||||
feed_fn() if feed_fn else None)
|
||||
|
||||
summary_io.SummaryWriterCache.clear()
|
||||
return loss
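For orientation, a minimal call sketch mirroring the tests further below (a hypothetical example, not part of this change; the output directory is a placeholder path):

```python
# Hedged sketch: drives _monitored_train (a private helper) on a toy graph.
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.learn.python import learn

g = tf.Graph()
with g.as_default():
  variables_lib.create_global_step()
  loss_op = tf.constant(2.0)
  train_op = tf.assign_add(variables_lib.get_global_step(), 1)

loss = learn.graph_actions._monitored_train(  # pylint: disable=protected-access
    g,
    output_dir='/tmp/monitored_train_demo',   # placeholder path
    train_op=train_op,
    loss_op=loss_op,
    steps=1)
```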
|
||||
|
||||
|
||||
@_graph_action_deprecation
|
||||
def train(graph,
|
||||
output_dir,
|
||||
|
@ -27,7 +27,6 @@ from tensorflow.contrib.framework.python.ops import variables as variables_lib
|
||||
from tensorflow.contrib.learn.python import learn
|
||||
from tensorflow.contrib.learn.python.learn.monitors import BaseMonitor
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -36,7 +35,6 @@ from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.summary import summary
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
|
||||
@ -220,7 +218,7 @@ class GraphActionsTest(test.TestCase):
|
||||
self.assertTrue(request_stop.called)
|
||||
|
||||
def test_run_feeds_iter_calls_resources_init(self):
|
||||
with ops.Graph().as_default() as g:
|
||||
with ops.Graph().as_default():
|
||||
in0, _, _ = self._build_inference_graph()
|
||||
handle = test_ops.stub_resource_handle_op(container='a', shared_name='b')
|
||||
resources.register_resource(
|
||||
@ -314,7 +312,7 @@ class GraphActionsTest(test.TestCase):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
variables_lib.create_global_step()
|
||||
v = variables.Variable(1.0)
|
||||
w = variables.Variable(
|
||||
variables.Variable(
|
||||
v + 1, collections=[ops.GraphKeys.LOCAL_VARIABLES], trainable=False)
|
||||
ready_for_local_init_op = variables.report_uninitialized_variables(
|
||||
variables.global_variables())
|
||||
@ -396,223 +394,11 @@ class GraphActionsTest(test.TestCase):
|
||||
}},
|
||||
expected_session_logs=[])
|
||||
|
||||
def test_train_invalid_args(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
train_op = constant_op.constant(1.0)
|
||||
loss_op = constant_op.constant(2.0)
|
||||
with self.assertRaisesRegexp(ValueError, 'utput directory'):
|
||||
learn.graph_actions._monitored_train(
|
||||
g, # pylint: disable=protected-access
|
||||
output_dir=None,
|
||||
train_op=train_op,
|
||||
loss_op=loss_op)
|
||||
with self.assertRaisesRegexp(ValueError, 'utput directory'):
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir='',
|
||||
train_op=constant_op.constant(1.0),
|
||||
loss_op=constant_op.constant(2.0))
|
||||
with self.assertRaisesRegexp(ValueError, 'train_op'):
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=None,
|
||||
loss_op=loss_op)
|
||||
with self.assertRaisesRegexp(ValueError, 'loss_op'):
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=constant_op.constant(1.0),
|
||||
loss_op=None)
|
||||
with self.assertRaisesRegexp(ValueError, 'global_step'):
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=constant_op.constant(1.0),
|
||||
loss_op=loss_op)
|
||||
|
||||
# TODO(ptucker): Resume training from previous ckpt.
|
||||
# TODO(ptucker): !supervisor_is_chief
|
||||
# TODO(ptucker): Custom init op for training.
|
||||
# TODO(ptucker): Mock supervisor, and assert all interactions.
|
||||
|
||||
def test_train(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
with ops.control_dependencies(self._build_inference_graph()):
|
||||
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
|
||||
writer = learn.graph_actions.get_summary_writer(self._output_dir)
|
||||
self._assert_summaries(self._output_dir, writer)
|
||||
self._assert_ckpt(self._output_dir, False)
|
||||
loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=constant_op.constant(2.0),
|
||||
steps=1)
|
||||
meta_graph_def = meta_graph.create_meta_graph_def(
|
||||
graph_def=g.as_graph_def(add_shapes=True),
|
||||
saver_def=monitored_session.Scaffold().finalize().saver.saver_def)
|
||||
self.assertEqual(2.0, loss)
|
||||
self._assert_summaries(
|
||||
self._output_dir,
|
||||
writer,
|
||||
expected_graphs=[g],
|
||||
expected_meta_graphs=[meta_graph_def])
|
||||
self._assert_ckpt(self._output_dir, True)
|
||||
|
||||
def test_train_steps_is_incremental(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
with ops.control_dependencies(self._build_inference_graph()):
|
||||
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=constant_op.constant(2.0),
|
||||
steps=10)
|
||||
step = checkpoint_utils.load_variable(
|
||||
self._output_dir, variables_lib.get_global_step().name)
|
||||
self.assertEqual(10, step)
|
||||
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
with ops.control_dependencies(self._build_inference_graph()):
|
||||
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=constant_op.constant(2.0),
|
||||
steps=15)
|
||||
step = checkpoint_utils.load_variable(
|
||||
self._output_dir, variables_lib.get_global_step().name)
|
||||
self.assertEqual(25, step)
|
||||
|
||||
def test_train_max_steps_is_not_incremental(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
with ops.control_dependencies(self._build_inference_graph()):
|
||||
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=constant_op.constant(2.0),
|
||||
max_steps=10)
|
||||
step = checkpoint_utils.load_variable(
|
||||
self._output_dir, variables_lib.get_global_step().name)
|
||||
self.assertEqual(10, step)
|
||||
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
with ops.control_dependencies(self._build_inference_graph()):
|
||||
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=constant_op.constant(2.0),
|
||||
max_steps=15)
|
||||
step = checkpoint_utils.load_variable(
|
||||
self._output_dir, variables_lib.get_global_step().name)
|
||||
self.assertEqual(15, step)
|
||||
|
||||
def test_train_skip_train_if_max_step_already_saved(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
with ops.control_dependencies(self._build_inference_graph()):
|
||||
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=constant_op.constant(2.0),
|
||||
max_steps=10)
|
||||
step = checkpoint_utils.load_variable(
|
||||
self._output_dir, variables_lib.get_global_step().name)
|
||||
self.assertEqual(10, step)
|
||||
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
with ops.control_dependencies(self._build_inference_graph()):
|
||||
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
|
||||
learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=constant_op.constant(2.0),
|
||||
max_steps=10)
|
||||
step = checkpoint_utils.load_variable(
|
||||
self._output_dir, variables_lib.get_global_step().name)
|
||||
self.assertEqual(10, step)
|
||||
|
||||
def test_train_loss(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
variables_lib.create_global_step()
|
||||
loss_var = variables_lib.local_variable(10.0)
|
||||
train_op = control_flow_ops.group(
|
||||
state_ops.assign_add(variables_lib.get_global_step(), 1),
|
||||
state_ops.assign_add(loss_var, -1.0))
|
||||
writer = learn.graph_actions.get_summary_writer(self._output_dir)
|
||||
self._assert_summaries(self._output_dir, writer)
|
||||
self._assert_ckpt(self._output_dir, False)
|
||||
loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=loss_var.value(),
|
||||
steps=6)
|
||||
self.assertEqual(4.0, loss)
|
||||
self._assert_summaries(
|
||||
self._output_dir,
|
||||
writer,
|
||||
expected_graphs=[g],
|
||||
expected_meta_graphs=None)
|
||||
self._assert_ckpt(self._output_dir, True)
|
||||
|
||||
def test_train_summaries(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
with ops.control_dependencies(self._build_inference_graph()):
|
||||
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
|
||||
loss_op = constant_op.constant(2.0)
|
||||
summary.scalar('loss', loss_op)
|
||||
writer = learn.graph_actions.get_summary_writer(self._output_dir)
|
||||
self._assert_summaries(self._output_dir, writer)
|
||||
self._assert_ckpt(self._output_dir, False)
|
||||
loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=loss_op,
|
||||
steps=1)
|
||||
meta_graph_def = meta_graph.create_meta_graph_def(
|
||||
graph_def=g.as_graph_def(add_shapes=True),
|
||||
saver_def=monitored_session.Scaffold().finalize().saver.saver_def)
|
||||
self.assertEqual(2.0, loss)
|
||||
self._assert_summaries(
|
||||
self._output_dir,
|
||||
writer,
|
||||
expected_graphs=[g],
|
||||
expected_meta_graphs=[meta_graph_def],
|
||||
expected_summaries={1: {
|
||||
'loss': 2.0
|
||||
}})
|
||||
self._assert_ckpt(self._output_dir, True)
|
||||
|
||||
def test_train_override_saver(self):
|
||||
with ops.Graph().as_default() as g, self.test_session(g):
|
||||
with ops.control_dependencies(self._build_inference_graph()):
|
||||
train_op = state_ops.assign_add(variables_lib.get_global_step(), 1)
|
||||
self._assert_ckpt(self._output_dir, False)
|
||||
real_saver = saver_lib.Saver()
|
||||
saver = test.mock.Mock(wraps=real_saver, saver_def=real_saver.saver_def)
|
||||
ops.add_to_collection(ops.GraphKeys.SAVERS, saver)
|
||||
loss = learn.graph_actions._monitored_train( # pylint: disable=protected-access
|
||||
g,
|
||||
output_dir=self._output_dir,
|
||||
train_op=train_op,
|
||||
loss_op=constant_op.constant(2.0),
|
||||
steps=1)
|
||||
self.assertEqual(2.0, loss)
|
||||
self._assert_ckpt(self._output_dir, True)
|
||||
self.assertTrue(saver.build.called)
|
||||
self.assertEqual(1, saver.save.call_count)
|
||||
|
||||
|
||||
# TODO(ispir): remove following tests after deprecated train.
|
||||
class GraphActionsTrainTest(test.TestCase):
|
||||
|
@ -60,14 +60,6 @@ class LinearOperatorTriLTest(
|
||||
|
||||
return operator, mat, feed_dict
|
||||
|
||||
def test_assert_positive_definite(self):
|
||||
# Matrix with one positive eigenvalue and one negative eigenvalue.
|
||||
with self.test_session():
|
||||
tril = [[1., 0.], [1., -1.]]
|
||||
operator = linalg.LinearOperatorTriL(tril)
|
||||
with self.assertRaisesOpError("was not positive definite"):
|
||||
operator.assert_positive_definite().run()
|
||||
|
||||
def test_assert_non_singular(self):
|
||||
# Singular matrix with one positive eigenvalue and one zero eigenvalue.
|
||||
with self.test_session():
|
||||
|
@ -21,8 +21,10 @@ import numpy as np
|
||||
|
||||
from tensorflow.contrib import linalg as linalg_lib
|
||||
from tensorflow.contrib.linalg.python.ops import linear_operator_util
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -91,6 +93,142 @@ class AssertNoEntriesWithModulusZeroTest(test.TestCase):
|
||||
z, message="ABC123").run()
|
||||
|
||||
|
||||
class BroadcastMatrixBatchDimsTest(test.TestCase):
|
||||
|
||||
def test_zero_batch_matrices_returned_as_empty_list(self):
|
||||
self.assertAllEqual(
|
||||
[], linear_operator_util.broadcast_matrix_batch_dims([]))
|
||||
|
||||
def test_one_batch_matrix_returned_after_tensor_conversion(self):
|
||||
arr = rng.rand(2, 3, 4)
|
||||
tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
|
||||
self.assertTrue(isinstance(tensor, ops.Tensor))
|
||||
|
||||
with self.test_session():
|
||||
self.assertAllClose(arr, tensor.eval())
|
||||
|
||||
def test_static_dims_broadcast(self):
|
||||
# x.batch_shape = [3, 1, 2]
|
||||
# y.batch_shape = [4, 1]
|
||||
# broadcast batch shape = [3, 4, 2]
|
||||
x = rng.rand(3, 1, 2, 1, 5)
|
||||
y = rng.rand(4, 1, 3, 7)
|
||||
batch_of_zeros = np.zeros((3, 4, 2, 1, 1))
|
||||
x_bc_expected = x + batch_of_zeros
|
||||
y_bc_expected = y + batch_of_zeros
|
||||
|
||||
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
|
||||
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
|
||||
x_bc_, y_bc_ = sess.run([x_bc, y_bc])
|
||||
self.assertAllClose(x_bc_expected, x_bc_)
|
||||
self.assertAllClose(y_bc_expected, y_bc_)
|
||||
|
||||
def test_static_dims_broadcast_second_arg_higher_rank(self):
|
||||
# x.batch_shape = [1, 2]
|
||||
# y.batch_shape = [1, 3, 1]
|
||||
# broadcast batch shape = [1, 3, 2]
|
||||
x = rng.rand(1, 2, 1, 5)
|
||||
y = rng.rand(1, 3, 2, 3, 7)
|
||||
batch_of_zeros = np.zeros((1, 3, 2, 1, 1))
|
||||
x_bc_expected = x + batch_of_zeros
|
||||
y_bc_expected = y + batch_of_zeros
|
||||
|
||||
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
|
||||
|
||||
with self.test_session() as sess:
|
||||
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
|
||||
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
|
||||
x_bc_, y_bc_ = sess.run([x_bc, y_bc])
|
||||
self.assertAllClose(x_bc_expected, x_bc_)
|
||||
self.assertAllClose(y_bc_expected, y_bc_)
|
||||
|
||||
def test_dynamic_dims_broadcast_32bit(self):
|
||||
# x.batch_shape = [3, 1, 2]
|
||||
# y.batch_shape = [4, 1]
|
||||
# broadcast batch shape = [3, 4, 2]
|
||||
x = rng.rand(3, 1, 2, 1, 5).astype(np.float32)
|
||||
y = rng.rand(4, 1, 3, 7).astype(np.float32)
|
||||
batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
|
||||
x_bc_expected = x + batch_of_zeros
|
||||
y_bc_expected = y + batch_of_zeros
|
||||
|
||||
x_ph = array_ops.placeholder(dtypes.float32)
|
||||
y_ph = array_ops.placeholder(dtypes.float32)
|
||||
|
||||
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
|
||||
|
||||
with self.test_session() as sess:
|
||||
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
|
||||
self.assertAllClose(x_bc_expected, x_bc_)
|
||||
self.assertAllClose(y_bc_expected, y_bc_)
|
||||
|
||||
def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self):
|
||||
# x.batch_shape = [1, 2]
|
||||
# y.batch_shape = [3, 4, 1]
|
||||
# broadcast batch shape = [3, 4, 2]
|
||||
x = rng.rand(1, 2, 1, 5).astype(np.float32)
|
||||
y = rng.rand(3, 4, 1, 3, 7).astype(np.float32)
|
||||
batch_of_zeros = np.zeros((3, 4, 2, 1, 1)).astype(np.float32)
|
||||
x_bc_expected = x + batch_of_zeros
|
||||
y_bc_expected = y + batch_of_zeros
|
||||
|
||||
x_ph = array_ops.placeholder(dtypes.float32)
|
||||
y_ph = array_ops.placeholder(dtypes.float32)
|
||||
|
||||
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
|
||||
|
||||
with self.test_session() as sess:
|
||||
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
|
||||
self.assertAllClose(x_bc_expected, x_bc_)
|
||||
self.assertAllClose(y_bc_expected, y_bc_)
|
||||
|
||||
def test_less_than_two_dims_raises_static(self):
|
||||
x = rng.rand(3)
|
||||
y = rng.rand(1, 1)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "at least two dimensions"):
|
||||
linear_operator_util.broadcast_matrix_batch_dims([x, y])
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, "at least two dimensions"):
|
||||
linear_operator_util.broadcast_matrix_batch_dims([y, x])
|
||||
|
||||
|
||||
class MatmulWithBroadcastTest(test.TestCase):
|
||||
|
||||
def test_static_dims_broadcast(self):
|
||||
# batch_shape = [2]
|
||||
# for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7
|
||||
x = rng.rand(2, 1, 3)
|
||||
y = rng.rand(3, 7)
|
||||
y_broadcast = y + np.zeros((2, 1, 1))
|
||||
|
||||
with self.test_session():
|
||||
result = linear_operator_util.matmul_with_broadcast(x, y)
|
||||
self.assertAllEqual((2, 1, 7), result.get_shape())
|
||||
expected = math_ops.matmul(x, y_broadcast)
|
||||
self.assertAllEqual(expected.eval(), result.eval())
|
||||
|
||||
def test_dynamic_dims_broadcast_32bit(self):
|
||||
# batch_shape = [2]
|
||||
# for each batch member, we have a 1x3 matrix times a 3x7 matrix ==> 1x7
|
||||
x = rng.rand(2, 1, 3)
|
||||
y = rng.rand(3, 7)
|
||||
y_broadcast = y + np.zeros((2, 1, 1))
|
||||
|
||||
x_ph = array_ops.placeholder(dtypes.float64)
|
||||
y_ph = array_ops.placeholder(dtypes.float64)
|
||||
|
||||
with self.test_session() as sess:
|
||||
result, expected = sess.run(
|
||||
[linear_operator_util.matmul_with_broadcast(x_ph, y_ph),
|
||||
math_ops.matmul(x, y_broadcast)],
|
||||
feed_dict={x_ph: x, y_ph: y})
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
|
||||
class DomainDimensionStubOperator(object):
|
||||
|
||||
def __init__(self, domain_dimension):
|
||||
|
@ -155,8 +155,9 @@ class LinearOperator(object):
|
||||
is_self_adjoint: Expect that this operator is equal to its hermitian
|
||||
transpose. If `dtype` is real, this is equivalent to being symmetric.
|
||||
is_positive_definite: Expect that this operator is positive definite,
|
||||
meaning the real part of all eigenvalues is positive. We do not require
|
||||
the operator to be self-adjoint to be positive-definite. See:
|
||||
meaning the quadratic form `x^H A x` has positive real part for all
|
||||
nonzero `x`. Note that we do not require the operator to be
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||
#Extension_for_non_symmetric_matrices
|
||||
is_square: Expect that this operator acts like square [batch] matrices.
|
||||
@ -461,8 +462,9 @@ class LinearOperator(object):
|
||||
def assert_positive_definite(self, name="assert_positive_definite"):
|
||||
"""Returns an `Op` that asserts this operator is positive definite.
|
||||
|
||||
Here, positive definite means the real part of all eigenvalues is positive.
|
||||
We do not require the operator to be self-adjoint.
|
||||
Here, positive definite means that the quadratic form `x^H A x` has positive
|
||||
real part for all nonzero `x`. Note that we do not require the operator to
|
||||
be self-adjoint to be positive definite.
|
||||
|
||||
Args:
|
||||
name: A name to give this `Op`.
|
||||
|
@ -113,7 +113,7 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
is_self_adjoint=None,
|
||||
is_positive_definite=None,
|
||||
name=None):
|
||||
"""Initialize a `LinearOperatorComposition`.
|
||||
r"""Initialize a `LinearOperatorComposition`.
|
||||
|
||||
`LinearOperatorComposition` is initialized with a list of operators
|
||||
`[op_1,...,op_J]`. For the `apply` method to be well defined, the
|
||||
@ -127,9 +127,10 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
|
||||
is_self_adjoint: Expect that this operator is equal to its hermitian
|
||||
transpose.
|
||||
is_positive_definite: Expect that this operator is positive definite,
|
||||
meaning the real part of all eigenvalues is positive. We do not require
|
||||
the operator to be self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix
|
||||
meaning the quadratic form `x^H A x` has positive real part for all
|
||||
nonzero `x`. Note that we do not require the operator to be
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||
#Extension_for_non_symmetric_matrices
|
||||
name: A name for this `LinearOperator`. Default is the individual
|
||||
operators names joined with `_o_`.
|
||||
|
@ -114,7 +114,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
is_self_adjoint=None,
|
||||
is_positive_definite=None,
|
||||
name="LinearOperatorDiag"):
|
||||
"""Initialize a `LinearOperatorDiag`.
|
||||
r"""Initialize a `LinearOperatorDiag`.
|
||||
|
||||
Args:
|
||||
diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
|
||||
@ -124,9 +124,10 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
|
||||
is_self_adjoint: Expect that this operator is equal to its hermitian
|
||||
transpose. If `diag.dtype` is real, this is auto-set to `True`.
|
||||
is_positive_definite: Expect that this operator is positive definite,
|
||||
meaning the real part of all eigenvalues is positive. We do not require
|
||||
the operator to be self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix
|
||||
meaning the quadratic form `x^H A x` has positive real part for all
|
||||
nonzero `x`. Note that we do not require the operator to be
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||
#Extension_for_non_symmetric_matrices
|
||||
name: A name for this `LinearOperator`.
|
||||
|
||||
|
@ -109,7 +109,7 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
|
||||
is_self_adjoint=None,
|
||||
is_positive_definite=None,
|
||||
name="LinearOperatorFullMatrix"):
|
||||
"""Initialize a `LinearOperatorFullMatrix`.
|
||||
r"""Initialize a `LinearOperatorFullMatrix`.
|
||||
|
||||
Args:
|
||||
matrix: Shape `[B1,...,Bb, M, N]` with `b >= 0`, `M, N >= 0`.
|
||||
@ -118,9 +118,10 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator):
|
||||
is_self_adjoint: Expect that this operator is equal to its hermitian
|
||||
transpose.
|
||||
is_positive_definite: Expect that this operator is positive definite,
|
||||
meaning the real part of all eigenvalues is positive. We do not require
|
||||
the operator to be self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix
|
||||
meaning the quadratic form `x^H A x` has positive real part for all
|
||||
nonzero `x`. Note that we do not require the operator to be
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||
#Extension_for_non_symmetric_matrices
|
||||
name: A name for this `LinearOperator`.
|
||||
|
||||
|
@ -200,7 +200,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
||||
is_positive_definite=True,
|
||||
assert_proper_shapes=False,
|
||||
name="LinearOperatorIdentity"):
|
||||
"""Initialize a `LinearOperatorIdentity`.
|
||||
r"""Initialize a `LinearOperatorIdentity`.
|
||||
|
||||
The `LinearOperatorIdentity` is initialized with arguments defining `dtype`
|
||||
and shape.
|
||||
@ -218,7 +218,12 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
|
||||
is_non_singular: Expect that this operator is non-singular.
|
||||
is_self_adjoint: Expect that this operator is equal to its hermitian
|
||||
transpose.
|
||||
is_positive_definite: Expect that this operator is positive definite.
|
||||
is_positive_definite: Expect that this operator is positive definite,
|
||||
meaning the quadratic form `x^H A x` has positive real part for all
|
||||
nonzero `x`. Note that we do not require the operator to be
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||
#Extension_for_non_symmetric_matrices
|
||||
assert_proper_shapes: Python `bool`. If `False`, only perform static
|
||||
checks that initialization and method arguments have proper shape.
|
||||
If `True`, and static checks are inconclusive, add asserts to the graph.
|
||||
@ -523,7 +528,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
|
||||
is_positive_definite=None,
|
||||
assert_proper_shapes=False,
|
||||
name="LinearOperatorScaledIdentity"):
|
||||
"""Initialize a `LinearOperatorScaledIdentity`.
|
||||
r"""Initialize a `LinearOperatorScaledIdentity`.
|
||||
|
||||
The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which
|
||||
determines the size of each identity matrix, and a `multiplier`,
|
||||
@ -538,7 +543,12 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
|
||||
is_non_singular: Expect that this operator is non-singular.
|
||||
is_self_adjoint: Expect that this operator is equal to its hermitian
|
||||
transpose.
|
||||
is_positive_definite: Expect that this operator is positive definite.
|
||||
is_positive_definite: Expect that this operator is positive definite,
|
||||
meaning the quadratic form `x^H A x` has positive real part for all
|
||||
nonzero `x`. Note that we do not require the operator to be
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||
#Extension_for_non_symmetric_matrices
|
||||
assert_proper_shapes: Python `bool`. If `False`, only perform static
|
||||
checks that initialization and method arguments have proper shape.
|
||||
If `True`, and static checks are inconclusive, add asserts to the graph.
|
||||
|
@ -23,7 +23,6 @@ from tensorflow.contrib.linalg.python.ops import linear_operator_util
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
|
||||
@ -108,7 +107,7 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
|
||||
is_self_adjoint=None,
|
||||
is_positive_definite=None,
|
||||
name="LinearOperatorTriL"):
|
||||
"""Initialize a `LinearOperatorTriL`.
|
||||
r"""Initialize a `LinearOperatorTriL`.
|
||||
|
||||
Args:
|
||||
tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
|
||||
@ -122,9 +121,10 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
|
||||
real-valued diagonal entries. In this case it is advised to use
|
||||
`LinearOperatorDiag`.
|
||||
is_positive_definite: Expect that this operator is positive definite,
|
||||
meaning the real part of all eigenvalues is positive. We do not require
|
||||
the operator to be self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix
|
||||
meaning the quadratic form `x^H A x` has positive real part for all
|
||||
nonzero `x`. Note that we do not require the operator to be
|
||||
self-adjoint to be positive-definite. See:
|
||||
https://en.wikipedia.org/wiki/Positive-definite_matrix\
|
||||
#Extension_for_non_symmetric_matrices
|
||||
name: A name for this `LinearOperator`.
|
||||
|
||||
@ -173,20 +173,6 @@ class LinearOperatorTriL(linear_operator.LinearOperator):
|
||||
self._diag,
|
||||
message="Singular operator: Diagonal contained zero values.")
|
||||
|
||||
def _assert_positive_definite(self):
|
||||
if self.dtype.is_complex:
|
||||
message = (
|
||||
"Diagonal operator had diagonal entries with non-positive real part, "
|
||||
"thus was not positive definite.")
|
||||
else:
|
||||
message = (
|
||||
"Real diagonal operator had non-positive diagonal entries, "
|
||||
"thus was not positive definite.")
|
||||
|
||||
return check_ops.assert_positive(
|
||||
math_ops.real(self._diag),
|
||||
message=message)
|
||||
|
||||
def _apply(self, x, adjoint=False):
|
||||
return math_ops.matmul(self._tril, x, adjoint_a=adjoint)
|
||||
|
||||
|
@ -170,6 +170,8 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator):
|
||||
Default is `None`, unless `base_operator` is positive-definite
|
||||
`v = None` (meaning `u=v`), and `is_diag_update_positive`, in which case
|
||||
this defaults to `True`.
|
||||
Note that we say an operator is positive definite when the quadratic
|
||||
form `x^H A x` has positive real part for all nonzero `x`.
|
||||
is_square: Expect that this operator acts like square [batch] matrices.
|
||||
name: A name for this `LinearOperator`.
|
||||
|
||||
|
@ -87,13 +87,208 @@ def assert_compatible_matrix_dimensions(operator, x):
|
||||
assert_same_dd = check_ops.assert_equal(
|
||||
array_ops.shape(x)[-2],
|
||||
operator.domain_dimension_tensor(),
|
||||
message=(
|
||||
"Incompatible matrix dimensions. "
|
||||
"shape[-2] of argument to be the same as this operator"))
|
||||
message=("Incompatible matrix dimensions. "
|
||||
"shape[-2] of argument to be the same as this operator"))
|
||||
|
||||
return assert_same_dd
|
||||
|
||||
|
||||
def assert_is_batch_matrix(tensor):
|
||||
"""Static assert that `tensor` has rank `2` or higher."""
|
||||
sh = tensor.get_shape()
|
||||
if sh.ndims is not None and sh.ndims < 2:
|
||||
raise ValueError(
|
||||
"Expected [batch] matrix to have at least two dimensions. Found: "
|
||||
"%s" % tensor)
|
||||
|
||||
|
||||
def broadcast_matrix_batch_dims(batch_matrices, name=None):
|
||||
"""Broadcast leading dimensions of zero or more [batch] matrices.
|
||||
|
||||
Example broadcasting one batch dim of two simple matrices.
|
||||
|
||||
```python
|
||||
x = [[1, 2],
|
||||
[3, 4]] # Shape [2, 2], no batch dims
|
||||
|
||||
y = [[[1]]] # Shape [1, 1, 1], 1 batch dim of shape [1]
|
||||
|
||||
x_bc, y_bc = broadcast_matrix_batch_dims([x, y])
|
||||
|
||||
x_bc
|
||||
==> [[[1, 2],
|
||||
[3, 4]]] # Shape [1, 2, 2], 1 batch dim of shape [1].
|
||||
|
||||
y_bc
|
||||
==> same as y
|
||||
```
|
||||
|
||||
Example broadcasting many batch dims
|
||||
|
||||
```python
|
||||
x = tf.random_normal(shape=(2, 3, 1, 4, 4))
|
||||
y = tf.random_normal(shape=(1, 3, 2, 5, 5))
|
||||
x_bc, y_bc = broadcast_matrix_batch_dims([x, y])
|
||||
|
||||
x_bc.shape
|
||||
==> (2, 3, 2, 4, 4)
|
||||
|
||||
y_bc.shape
|
||||
==> (2, 3, 2, 5, 5)
|
||||
```
|
||||
|
||||
Args:
|
||||
batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
|
||||
name: A string name to prepend to created ops.
|
||||
|
||||
Returns:
|
||||
bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
|
||||
the values from `batch_matrices[i]`, with possibly broadcast batch dims.
|
||||
|
||||
Raises:
|
||||
ValueError: If any input `Tensor` is statically determined to have less
|
||||
than two dimensions.
|
||||
"""
|
||||
with ops.name_scope(
|
||||
name or "broadcast_matrix_batch_dims", values=batch_matrices):
|
||||
check_ops.assert_proper_iterable(batch_matrices)
|
||||
batch_matrices = list(batch_matrices)
|
||||
|
||||
for i, mat in enumerate(batch_matrices):
|
||||
batch_matrices[i] = ops.convert_to_tensor(mat)
|
||||
assert_is_batch_matrix(batch_matrices[i])
|
||||
|
||||
if len(batch_matrices) < 2:
|
||||
return batch_matrices
|
||||
|
||||
# Try static broadcasting.
|
||||
# bcast_batch_shape is the broadcast batch shape of ALL matrices.
|
||||
# E.g. if batch_matrices = [x, y], with
|
||||
# x.shape = [2, j, k] (batch shape = [2])
|
||||
# y.shape = [3, 1, l, m] (batch shape = [3, 1])
|
||||
# ==> bcast_batch_shape = [3, 2]
|
||||
bcast_batch_shape = batch_matrices[0].get_shape()[:-2]
|
||||
for mat in batch_matrices[1:]:
|
||||
bcast_batch_shape = array_ops.broadcast_static_shape(
|
||||
bcast_batch_shape, mat.get_shape()[:-2])
|
||||
if bcast_batch_shape.is_fully_defined():
|
||||
# The [1, 1] at the end will broadcast with anything.
|
||||
bcast_shape = bcast_batch_shape.concatenate([1, 1])
|
||||
for i, mat in enumerate(batch_matrices):
|
||||
if mat.get_shape()[:-2] != bcast_batch_shape:
|
||||
batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
|
||||
return batch_matrices
|
||||
|
||||
# Since static didn't work, do dynamic, which always copies data.
|
||||
bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
|
||||
for mat in batch_matrices[1:]:
|
||||
bcast_batch_shape = array_ops.broadcast_dynamic_shape(
|
||||
bcast_batch_shape, array_ops.shape(mat)[:-2])
|
||||
bcast_shape = array_ops.concat([bcast_batch_shape, [1, 1]], axis=0)
|
||||
for i, mat in enumerate(batch_matrices):
|
||||
batch_matrices[i] = _broadcast_to_shape(mat, bcast_shape)
|
||||
|
||||
return batch_matrices
|
||||
|
||||
|
||||
def _broadcast_to_shape(x, shape):
|
||||
return x + array_ops.zeros(shape=shape, dtype=x.dtype)
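A rough NumPy illustration of the zeros-addition trick that `_broadcast_to_shape` relies on (NumPy is a stand-in here; the shapes mirror the test cases elsewhere in this commit): adding a zeros array whose shape ends in `[1, 1]` broadcasts the leading batch dimensions while leaving each matrix's own two dimensions untouched.

```python
import numpy as np

x = np.random.rand(3, 1, 2, 1, 5)   # batch shape [3, 1, 2], matrix shape [1, 5]
y = np.random.rand(4, 1, 3, 7)      # batch shape [4, 1],    matrix shape [3, 7]

# The broadcast of batch shapes [3, 1, 2] and [4, 1] is [3, 4, 2]; the trailing
# [1, 1] broadcasts against any matrix shape, so only batch dims are expanded.
bcast_zeros = np.zeros((3, 4, 2, 1, 1))

x_bc = x + bcast_zeros
y_bc = y + bcast_zeros
assert x_bc.shape == (3, 4, 2, 1, 5)
assert y_bc.shape == (3, 4, 2, 3, 7)
```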
|
||||
|
||||
|
||||
def matmul_with_broadcast(a,
|
||||
b,
|
||||
transpose_a=False,
|
||||
transpose_b=False,
|
||||
adjoint_a=False,
|
||||
adjoint_b=False,
|
||||
a_is_sparse=False,
|
||||
b_is_sparse=False,
|
||||
name=None):
|
||||
"""Multiplies matrix `a` by matrix `b`, producing `a @ b`.
|
||||
|
||||
The inputs must be matrices (or tensors of rank > 2, representing batches of
|
||||
matrices).
|
||||
|
||||
Both matrices must be of the same type. The supported types are:
|
||||
`float16`, `float32`, `float64`, `int32`, `complex64`, `complex128`.
|
||||
|
||||
Either matrix can be transposed or adjointed (conjugated and transposed) on
|
||||
the fly by setting one of the corresponding flags to `True`. These are `False`
|
||||
by default.
|
||||
|
||||
If one or both of the matrices contain a lot of zeros, a more efficient
|
||||
multiplication algorithm can be used by setting the corresponding
|
||||
`a_is_sparse` or `b_is_sparse` flag to `True`. These are `False` by default.
|
||||
This optimization is only available for plain matrices (rank-2 tensors) with
|
||||
datatypes `bfloat16` or `float32`.
|
||||
|
||||
For example:
|
||||
|
||||
```python
|
||||
# A 2-batch of 3x4 matrices
|
||||
a = tf.random_normal(shape=(2, 3, 4))
|
||||
|
||||
# A single 4x5 matrix
|
||||
b = tf.random_normal(shape=(4, 5))
|
||||
|
||||
result = matmul_with_broadcast(a, b)
|
||||
|
||||
result.shape
|
||||
==> (2, 3, 5)
|
||||
|
||||
result[0,...]
|
||||
==> tf.matmul(a[0,...], b)
|
||||
|
||||
result[1,...]
|
||||
==> tf.matmul(a[1,...], b)
|
||||
```
|
||||
|
||||
Args:
|
||||
a: `Tensor` of type `float16`, `float32`, `float64`, `int32`, `complex64`,
|
||||
`complex128` and `rank > 1`.
|
||||
b: `Tensor` with same type as `a` having compatible matrix dimensions and
|
||||
broadcastable batch dimensions.
|
||||
transpose_a: If `True`, `a` is transposed before multiplication.
|
||||
transpose_b: If `True`, `b` is transposed before multiplication.
|
||||
adjoint_a: If `True`, `a` is conjugated and transposed before
|
||||
multiplication.
|
||||
adjoint_b: If `True`, `b` is conjugated and transposed before
|
||||
multiplication.
|
||||
a_is_sparse: If `True`, `a` is treated as a sparse matrix.
|
||||
b_is_sparse: If `True`, `b` is treated as a sparse matrix.
|
||||
name: Name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` of the same type as `a` and `b` where each inner-most matrix is
|
||||
the product of the corresponding matrices in `a` and `b`, e.g. if all
|
||||
transpose or adjoint attributes are `False`:
|
||||
|
||||
The leading shape of `output` is the result of broadcasting the leading
|
||||
dimensions of `a` and `b`.
|
||||
|
||||
`output`[..., i, j] = sum_k (`a`[..., i, k] * `b`[..., k, j]),
|
||||
for all indices i, j.
|
||||
|
||||
Note: This is matrix product, not element-wise product.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If transpose_a and adjoint_a, or transpose_b and adjoint_b
|
||||
are both set to True.
|
||||
"""
|
||||
with ops.name_scope(name, "MatMulWithBroadcast", [a, b]) as name:
|
||||
a, b = broadcast_matrix_batch_dims([a, b])
|
||||
return math_ops.matmul(
|
||||
a,
|
||||
b,
|
||||
transpose_a=transpose_a,
|
||||
transpose_b=transpose_b,
|
||||
adjoint_a=adjoint_a,
|
||||
adjoint_b=adjoint_b,
|
||||
a_is_sparse=a_is_sparse,
|
||||
b_is_sparse=b_is_sparse)
|
||||
|
||||
|
||||
def shape_tensor(shape, name=None):
|
||||
"""Convert Tensor using default type, unless empty list or tuple."""
|
||||
# Works just like random_ops._ShapeTensor.
|
||||
|
@ -270,6 +270,8 @@ def sdca_model_fn(features, labels, mode, params, config=None):
|
||||
|
||||
with variable_scope.variable_op_scope(features.values(),
|
||||
parent_scope) as scope:
|
||||
features = features.copy()
|
||||
features.update(layers.transform_features(features, feature_columns))
|
||||
logits, columns_to_variables, bias = (
|
||||
layers.weighted_sum_from_feature_columns(
|
||||
columns_to_tensors=features,
|
||||
|
@ -1879,11 +1879,11 @@ def streaming_pearson_correlation(predictions,
|
||||
math_ops.multiply(math_ops.sqrt(var_predictions),
|
||||
math_ops.sqrt(var_labels)),
|
||||
'pearson_r')
|
||||
with ops.control_dependencies(
|
||||
[update_cov, update_var_predictions, update_var_labels]):
|
||||
update_op = _safe_div(update_cov, math_ops.multiply(
|
||||
math_ops.sqrt(update_var_predictions),
|
||||
math_ops.sqrt(update_var_labels)), 'update_op')
|
||||
update_op = _safe_div(
|
||||
update_cov,
|
||||
math_ops.multiply(math_ops.sqrt(update_var_predictions),
|
||||
math_ops.sqrt(update_var_labels)),
|
||||
'update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, pearson_r)
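For reference, both `pearson_r` and `update_op` here compute the standard Pearson correlation of predictions `p` and labels `l` (the usual textbook definition, nothing new in this change); the diff only reflows the `_safe_div` call that builds `update_op`.

```latex
r_{p,l} = \frac{\operatorname{cov}(p,\, l)}{\sqrt{\operatorname{var}(p)}\,\sqrt{\operatorname{var}(l)}}
```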
|
||||
|
@ -42,6 +42,7 @@ See @{$python/contrib.rnn} guide.
|
||||
@@GridLSTMCell
|
||||
@@BidirectionalGridLSTMCell
|
||||
@@NASCell
|
||||
@@PhasedLSTMCell
|
||||
|
||||
### RNNCell wrappers
|
||||
@@AttentionCellWrapper
|
||||
|
@ -173,7 +173,6 @@ class RNNCellTest(test.TestCase):
|
||||
with self.test_session() as sess:
|
||||
num_units = 8
|
||||
batch_size = 3
|
||||
input_size = 4
|
||||
feature_size = 2
|
||||
frequency_skip = 1
|
||||
num_frequency_blocks = [1, 1]
|
||||
@ -844,6 +843,45 @@ class RNNCellTest(test.TestCase):
|
||||
"be set to num_units at cell init."):
|
||||
cell(inputs, init_state)
|
||||
|
||||
def testPhasedLSTMCell(self):
|
||||
with self.test_session() as sess:
|
||||
num_units = 2
|
||||
batch_size = 3
|
||||
input_size = 4
|
||||
expected_state_c = np.array(
|
||||
[[2.954548e-01, 8.354891e-04],
|
||||
[2.834632e-01, 8.158963e-01],
|
||||
[2.291694e-01, 1.325745e-04]],
|
||||
dtype=np.float32)
|
||||
expected_state_h = np.array(
|
||||
[[2.116566e-01, 5.985238e-04],
|
||||
[2.137760e-01, 6.153145e-01],
|
||||
[1.742966e-01, 1.008306e-04]],
|
||||
dtype=np.float32)
|
||||
with variable_scope.variable_scope(
|
||||
"root", initializer=init_ops.constant_initializer(0.5)):
|
||||
t = array_ops.zeros([batch_size, 1], dtype=dtypes.float64)
|
||||
x = array_ops.zeros([batch_size, input_size])
|
||||
c0 = array_ops.zeros([batch_size, 2])
|
||||
h0 = array_ops.zeros([batch_size, 2])
|
||||
state0 = core_rnn_cell_impl.LSTMStateTuple(c0, h0)
|
||||
output, state = rnn_cell.PhasedLSTMCell(num_units=num_units)((t, x),
|
||||
state0)
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([output, state], {
|
||||
t.name:
|
||||
np.array([[1.], [2.], [3.]]),
|
||||
x.name:
|
||||
np.array([[1., 1., 1., 1.],
|
||||
[2., 2., 2., 2.],
|
||||
[3., 3., 3., 3.]]),
|
||||
})
|
||||
# This is a smoke test, making sure expected values are unchanged.
|
||||
self.assertEqual(len(res), 2)
|
||||
self.assertAllClose(res[0], res[1].h)
|
||||
self.assertAllClose(res[1].c, expected_state_c)
|
||||
self.assertAllClose(res[1].h, expected_state_h)
|
||||
|
||||
|
||||
class LayerNormBasicLSTMCellTest(test.TestCase):
|
||||
|
||||
|
@ -33,6 +33,7 @@ from tensorflow.python.ops import clip_ops
|
||||
from tensorflow.python.ops import init_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope as vs
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import nest
|
||||
@ -1683,3 +1684,178 @@ class CompiledWrapper(core_rnn_cell.RNNCell):
|
||||
|
||||
with jit.experimental_jit_scope(compile_ops=compile_ops):
|
||||
return self._cell(inputs, state, scope=scope)
|
||||
|
||||
|
||||
def _random_exp_initializer(minval,
|
||||
maxval,
|
||||
seed=None,
|
||||
dtype=dtypes.float32):
|
||||
"""Returns an exponential distribution initializer.
|
||||
|
||||
Args:
|
||||
minval: float or a scalar float Tensor. With value > 0. Lower bound of the
|
||||
range of random values to generate.
|
||||
maxval: float or a scalar float Tensor. With value > minval. Upper bound of
|
||||
the range of random values to generate.
|
||||
seed: An integer. Used to create random seeds.
|
||||
dtype: The data type.
|
||||
|
||||
Returns:
|
||||
An initializer that generates tensors with an exponential distribution.
|
||||
"""
|
||||
|
||||
def _initializer(shape, dtype=dtype, partition_info=None):
|
||||
del partition_info # Unused.
|
||||
return math_ops.exp(
|
||||
random_ops.random_uniform(
|
||||
shape,
|
||||
math_ops.log(minval),
|
||||
math_ops.log(maxval),
|
||||
dtype,
|
||||
seed=seed))
|
||||
|
||||
return _initializer
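A quick NumPy sanity sketch of what this initializer draws (NumPy stand-in for the TensorFlow ops; the `minval`/`maxval` values are illustrative): samples are uniform in log-space, i.e. `e` raised to `U(log(minval), log(maxval))`, so every decade between the bounds is equally likely.

```python
import numpy as np

minval, maxval = 1.0, 1000.0   # e.g. period_init_min / period_init_max
samples = np.exp(np.random.uniform(np.log(minval), np.log(maxval), size=8))

# All samples lie in [minval, maxval] and their logs are uniformly spread.
assert np.all((samples >= minval) & (samples <= maxval))
```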
|
||||
|
||||
|
||||
class PhasedLSTMCell(core_rnn_cell.RNNCell):
|
||||
"""Phased LSTM recurrent network cell.
|
||||
|
||||
https://arxiv.org/pdf/1610.09513v1.pdf
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_units,
|
||||
use_peepholes=False,
|
||||
leak=0.001,
|
||||
ratio_on=0.1,
|
||||
trainable_ratio_on=True,
|
||||
period_init_min=1.0,
|
||||
period_init_max=1000.0,
|
||||
reuse=None):
|
||||
"""Initialize the Phased LSTM cell.
|
||||
|
||||
Args:
|
||||
num_units: int, The number of units in the Phased LSTM cell.
|
||||
use_peepholes: bool, set True to enable peephole connections.
|
||||
leak: float or scalar float Tensor with value in [0, 1]. Leak applied
|
||||
during training.
|
||||
ratio_on: float or scalar float Tensor with value in [0, 1]. Ratio of the
|
||||
period during which the gates are open.
|
||||
trainable_ratio_on: bool, whether ratio_on is trainable.
|
||||
period_init_min: float or scalar float Tensor. With value > 0.
|
||||
Minimum value of the initialized period.
|
||||
The period values are initialized by drawing from the distribution:
|
||||
e^U(log(period_init_min), log(period_init_max))
|
||||
Where U(.,.) is the uniform distribution.
|
||||
period_init_max: float or scalar float Tensor.
|
||||
With value > period_init_min. Maximum value of the initialized period.
|
||||
reuse: (optional) Python boolean describing whether to reuse variables
|
||||
in an existing scope. If not `True`, and the existing scope already has
|
||||
the given variables, an error is raised.
|
||||
"""
|
||||
self._num_units = num_units
|
||||
self._use_peepholes = use_peepholes
|
||||
self._leak = leak
|
||||
self._ratio_on = ratio_on
|
||||
self._trainable_ratio_on = trainable_ratio_on
|
||||
self._period_init_min = period_init_min
|
||||
self._period_init_max = period_init_max
|
||||
self._reuse = reuse
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
return core_rnn_cell.LSTMStateTuple(self._num_units, self._num_units)
|
||||
|
||||
@property
|
||||
def output_size(self):
|
||||
return self._num_units
|
||||
|
||||
def _mod(self, x, y):
|
||||
"""Modulo function that propagates x gradients."""
|
||||
return array_ops.stop_gradient(math_ops.mod(x, y) - x) + x
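# Note: the forward value of _mod is mod(x, y); because (mod(x, y) - x) is
# wrapped in stop_gradient, the gradient w.r.t. x is exactly 1 (a
# straight-through estimator) and the gradient w.r.t. y is 0, so backprop
# avoids the discontinuities of mod.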
|
||||
|
||||
def _get_cycle_ratio(self, time, phase, period):
|
||||
"""Compute the cycle ratio in the dtype of the time."""
|
||||
phase_casted = math_ops.cast(phase, dtype=time.dtype)
|
||||
period_casted = math_ops.cast(period, dtype=time.dtype)
|
||||
shifted_time = time - phase_casted
|
||||
cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
|
||||
return math_ops.cast(cycle_ratio, dtype=dtypes.float32)
|
||||
|
||||
def __call__(self, inputs, state, scope=None):
|
||||
"""Phased LSTM Cell.
|
||||
|
||||
Args:
|
||||
inputs: A tuple of 2 Tensors.
|
||||
The first Tensor has shape [batch, 1], and type float32 or float64.
|
||||
It stores the time.
|
||||
The second Tensor has shape [batch, features_size], and type float32.
|
||||
It stores the features.
|
||||
state: core_rnn_cell.LSTMStateTuple, state from previous timestep.
|
||||
scope: string, id of the variable scope.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A Tensor of float32, and shape [batch_size, num_units], representing the
|
||||
output of the cell.
|
||||
- A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape
|
||||
[batch_size, num_units], representing the new state and the output.
|
||||
"""
|
||||
with _checked_scope(self, scope or "phased_lstm_cell", reuse=self._reuse):
|
||||
(c_prev, h_prev) = state
|
||||
(time, x) = inputs
|
||||
|
||||
in_mask_gates = [x, h_prev]
|
||||
if self._use_peepholes:
|
||||
in_mask_gates.append(c_prev)
|
||||
|
||||
with vs.variable_scope("mask_gates"):
|
||||
mask_gates = math_ops.sigmoid(
|
||||
_linear(in_mask_gates, 2 * self._num_units, True))
|
||||
[input_gate, forget_gate] = array_ops.split(
|
||||
axis=1, num_or_size_splits=2, value=mask_gates)
|
||||
|
||||
with vs.variable_scope("new_input"):
|
||||
new_input = math_ops.tanh(
|
||||
_linear([x, h_prev], self._num_units, True))
|
||||
|
||||
new_c = (c_prev * forget_gate + input_gate * new_input)
|
||||
|
||||
in_out_gate = [x, h_prev]
|
||||
if self._use_peepholes:
|
||||
in_out_gate.append(new_c)
|
||||
|
||||
with vs.variable_scope("output_gate"):
|
||||
output_gate = math_ops.sigmoid(
|
||||
_linear(in_out_gate, self._num_units, True))
|
||||
|
||||
new_h = math_ops.tanh(new_c) * output_gate
|
||||
|
||||
period = vs.get_variable(
|
||||
"period", [self._num_units],
|
||||
initializer=_random_exp_initializer(
|
||||
self._period_init_min, self._period_init_max))
|
||||
phase = vs.get_variable(
|
||||
"phase", [self._num_units],
|
||||
initializer=init_ops.random_uniform_initializer(
|
||||
0., period.initial_value))
|
||||
ratio_on = vs.get_variable(
|
||||
"ratio_on", [self._num_units],
|
||||
initializer=init_ops.constant_initializer(self._ratio_on),
|
||||
trainable=self._trainable_ratio_on)
|
||||
|
||||
cycle_ratio = self._get_cycle_ratio(time, phase, period)
|
||||
|
||||
k_up = 2 * cycle_ratio / ratio_on
|
||||
k_down = 2 - k_up
|
||||
k_closed = self._leak * cycle_ratio
|
||||
|
||||
k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
|
||||
k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
|
||||
|
||||
new_c = k * new_c + (1 - k) * c_prev
|
||||
new_h = k * new_h + (1 - k) * h_prev
|
||||
|
||||
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
|
||||
|
||||
return new_h, new_state
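A compact NumPy sketch of the time-gate schedule `k` computed above (plain NumPy mirroring the formulas; the `ratio_on` and `leak` defaults are illustrative, matching the cell's defaults):

```python
import numpy as np

def phased_lstm_gate(t, phase, period, ratio_on=0.1, leak=0.001):
  """Openness k of the phased LSTM time gate at time t (per unit)."""
  cycle_ratio = np.mod(t - phase, period) / period
  k_up = 2.0 * cycle_ratio / ratio_on     # rising half of the open window
  k_down = 2.0 - k_up                     # falling half of the open window
  k_closed = leak * cycle_ratio           # small leak while the gate is shut
  k = np.where(cycle_ratio < ratio_on, k_down, k_closed)
  k = np.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
  return k

# k peaks at 1 halfway through the open fraction of each period and stays near
# leak * cycle_ratio for the remainder, so the state only updates sparsely.
print(phased_lstm_gate(t=np.array([0.02, 0.05, 0.5]), phase=0.0, period=1.0))
```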
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
# pylint: enable=unused-import
|
||||
|
||||
import sys
|
||||
import functools
|
||||
|
||||
import numpy as np
|
||||
@ -38,6 +39,15 @@ from tensorflow.python.util import nest
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
||||
|
||||
# for testing
|
||||
AttentionWrapperState = wrapper.AttentionWrapperState # pylint: disable=invalid-name
|
||||
LSTMStateTuple = core_rnn_cell.LSTMStateTuple # pylint: disable=invalid-name
|
||||
BasicDecoderOutput = basic_decoder.BasicDecoderOutput # pylint: disable=invalid-name
|
||||
float32 = np.float32
|
||||
int32 = np.int32
|
||||
array = np.array
|
||||
|
||||
|
||||
class AttentionWrapperTest(test.TestCase):
|
||||
|
||||
def assertAllClose(self, *args, **kwargs):
|
||||
@ -48,10 +58,11 @@ class AttentionWrapperTest(test.TestCase):
|
||||
|
||||
def _testWithAttention(self,
|
||||
create_attention_mechanism,
|
||||
expected_final_outputs,
|
||||
expected_final_output,
|
||||
expected_final_state,
|
||||
attention_mechanism_depth=3,
|
||||
attention_history=False):
|
||||
attention_history=False,
|
||||
name=""):
|
||||
encoder_sequence_length = [3, 2, 3, 1, 0]
|
||||
decoder_sequence_length = [2, 0, 1, 2, 3]
|
||||
batch_size = 5
|
||||
@ -126,7 +137,13 @@ class AttentionWrapperTest(test.TestCase):
|
||||
"state_attention_history": state_attention_history,
|
||||
})
|
||||
|
||||
nest.map_structure(self.assertAllClose, expected_final_outputs,
|
||||
print("Copy/paste (%s)\nexpected_final_output = " % name,
|
||||
sess_results["final_outputs"])
|
||||
sys.stdout.flush()
|
||||
print("Copy/paste (%s)\nexpected_final_state = " % name,
|
||||
sess_results["final_state"])
|
||||
sys.stdout.flush()
|
||||
nest.map_structure(self.assertAllClose, expected_final_output,
|
||||
sess_results["final_outputs"])
|
||||
nest.map_structure(self.assertAllClose, expected_final_state,
|
||||
sess_results["final_state"])
|
||||
@ -137,533 +154,534 @@ class AttentionWrapperTest(test.TestCase):
|
||||
np.transpose(sess_results["final_outputs"].rnn_output,
|
||||
(1, 0, 2)))
|
||||
|
||||
def testBahndahauNotNormalized(self):
|
||||
def testBahdanauNotNormalized(self):
|
||||
create_attention_mechanism = wrapper.BahdanauAttention
|
||||
|
||||
array = np.array
|
||||
float32 = np.float32
|
||||
int32 = np.int32
|
||||
|
||||
expected_final_outputs = basic_decoder.BasicDecoderOutput(
|
||||
expected_final_output = BasicDecoderOutput(
|
||||
rnn_output=array(
|
||||
[[[
|
||||
1.25166783e-02, -6.88887993e-03, 3.17239435e-03,
|
||||
-1.98234897e-03, 4.77387803e-03, -1.38330357e-02
|
||||
1.89980457e-03, 1.89681584e-03, 2.05339328e-03, -3.83376027e-03,
|
||||
-4.31808922e-03, -6.45466987e-03
|
||||
], [
|
||||
1.28883058e-02, -6.76271692e-03, 3.13419267e-03,
|
||||
-2.02183682e-03, 5.62057737e-03, -1.35373026e-02
|
||||
2.27232254e-03, 2.02509761e-03, 2.01666891e-03, -3.87230632e-03,
|
||||
-3.47119337e-03, -6.15991233e-03
|
||||
], [
|
||||
1.24917831e-02, -6.71574520e-03, 3.42238229e-03,
|
||||
-1.79501204e-03, 5.33161033e-03, -1.36620644e-02
|
||||
1.87640532e-03, 2.07374478e-03, 2.30582547e-03, -3.64564802e-03,
|
||||
-3.75995948e-03, -6.28685066e-03
|
||||
]], [[
|
||||
1.55150667e-02, -1.07274549e-02, 4.44198400e-03,
|
||||
-9.73310322e-04, 1.27242506e-02, -1.21861566e-02
|
||||
4.89835022e-03, -1.94158917e-03, 3.32316267e-03,
|
||||
-2.82446202e-03, 3.63192149e-03, -4.80734091e-03
|
||||
], [
|
||||
1.57585666e-02, -1.07965544e-02, 4.61554807e-03,
|
||||
-1.01510016e-03, 1.22341057e-02, -1.27029382e-02
|
||||
5.14256489e-03, -2.00877781e-03, 3.49807227e-03,
|
||||
-2.86567654e-03, 3.14202951e-03, -5.32575324e-03
|
||||
], [
|
||||
1.58304181e-02, -1.09712025e-02, 4.67861444e-03,
|
||||
-1.03920139e-03, 1.23004699e-02, -1.25949886e-02
|
||||
5.21511910e-03, -2.18198029e-03, 3.56219849e-03,
|
||||
-2.88951304e-03, 3.20866983e-03, -5.21918852e-03
|
||||
]], [[
|
||||
9.26700700e-03, -9.75431874e-03, -9.95740294e-04,
|
||||
-1.27463136e-06, 3.81659716e-03, -1.64887272e-02
|
||||
-1.34951377e-03, -9.68646549e-04, -2.11444520e-03,
|
||||
-1.85243192e-03, -5.27541339e-03, -9.10969637e-03
|
||||
], [
|
||||
9.25191958e-03, -9.80092678e-03, -8.48566880e-04,
|
||||
5.02091134e-05, 3.46567202e-03, -1.67435352e-02
|
||||
-1.36390887e-03, -1.01293903e-03, -1.96592091e-03,
|
||||
-1.80044665e-03, -5.62618347e-03, -9.36636236e-03
|
||||
], [
|
||||
9.48173273e-03, -9.52653307e-03, -8.79382715e-04,
|
||||
-3.07094306e-05, 4.05955408e-03, -1.67226996e-02
|
||||
-1.13357347e-03, -7.37126335e-04, -1.99582824e-03,
|
||||
-1.88097963e-03, -5.03196474e-03, -9.34652984e-03
|
||||
]], [[
|
||||
1.21462569e-02, -1.27578378e-02, 1.54045003e-04, 2.70257704e-03,
|
||||
7.79421115e-03, -8.14041123e-04
|
||||
1.52963377e-03, -3.97205260e-03, -9.64675564e-04,
|
||||
8.51404853e-04, -1.29804458e-03, 6.56467676e-03
|
||||
], [
|
||||
1.18412934e-02, -1.33513296e-02, 3.54760559e-05, 2.67801876e-03,
|
||||
6.99122995e-03, -9.46014654e-04
|
||||
1.22557906e-03, -4.56343032e-03, -1.08188344e-03,
|
||||
8.27252632e-04, -2.10058759e-03, 6.43082103e-03
|
||||
], [
|
||||
              ...                                   # remaining rnn_output rows elided
          ]]],
          dtype=float32),
      sample_id=array(
          [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
          [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
          dtype=int32))

    expected_final_state = wrapper.AttentionWrapperState(
        time=3,
        attention_history=(),
        cell_state=core_rnn_cell.LSTMStateTuple(
    expected_final_state = AttentionWrapperState(
        cell_state=LSTMStateTuple(
            c=array([...], dtype=float32),          # values elided
            h=array([...], dtype=float32)),         # values elided
        attention=array([...], dtype=float32),      # values elided
        time=3,
        attention_history=())

    self._testWithAttention(create_attention_mechanism, expected_final_outputs,
                            expected_final_state, attention_history=True)

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_history=True,
        name="testBahdanauNotNormalized")

  def testBahndahauNormalized(self):
  def testBahdanauNormalized(self):
    create_attention_mechanism = functools.partial(
        wrapper.BahdanauAttention, normalize=True, attention_r_initializer=2.0)
        wrapper.BahdanauAttention, normalize=True)

    array = np.array
    float32 = np.float32
    int32 = np.int32

    expected_final_output = basic_decoder.BasicDecoderOutput(
    expected_final_output = BasicDecoderOutput(
        rnn_output=array([...], dtype=float32),     # values elided
        sample_id=array(
            [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
            [[0, 0, 0], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
            dtype=int32))

    expected_final_state = wrapper.AttentionWrapperState(
        time=3,
        attention_history=(),
        cell_state=core_rnn_cell.LSTMStateTuple(
    expected_final_state = AttentionWrapperState(
        cell_state=LSTMStateTuple(
            c=array([...], dtype=float32),          # values elided
            h=array([...], dtype=float32)),         # values elided
        attention=array([...], dtype=float32),      # values elided
        time=3,
        attention_history=())

    self._testWithAttention(create_attention_mechanism, expected_final_output,
                            expected_final_state)
    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        name="testBahdanauNormalized")

  def testLuongNotNormalized(self):
    create_attention_mechanism = wrapper.LuongAttention

    array = np.array
    float32 = np.float32
    int32 = np.int32

    expected_final_output = basic_decoder.BasicDecoderOutput(
    expected_final_output = BasicDecoderOutput(
        rnn_output=array([...], dtype=float32),     # values elided
        sample_id=array(
            [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
            [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
            dtype=int32))

    expected_final_state = wrapper.AttentionWrapperState(
        time=3,
        attention_history=(),
        cell_state=core_rnn_cell.LSTMStateTuple(
    expected_final_state = AttentionWrapperState(
        cell_state=LSTMStateTuple(
            c=array([...], dtype=float32),          # values elided
            h=array([...], dtype=float32)),         # values elided
        attention=array([...], dtype=float32),      # values elided
        time=3,
        attention_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9)
        attention_mechanism_depth=9,
        name="testLuongNotNormalized")

  def testLuongNormalized(self):
  def testLuongScaled(self):
    create_attention_mechanism = functools.partial(
        wrapper.LuongAttention, normalize=True, attention_r_initializer=2.0)
        wrapper.LuongAttention, scale=True)

    array = np.array
    float32 = np.float32
    int32 = np.int32

    expected_final_output = basic_decoder.BasicDecoderOutput(
    expected_final_output = BasicDecoderOutput(
        rnn_output=array([...], dtype=float32),     # values elided
        sample_id=array(
            [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
            [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
            dtype=int32))

    expected_final_state = wrapper.AttentionWrapperState(
        time=3,
        attention_history=(),
        cell_state=core_rnn_cell.LSTMStateTuple(
    expected_final_state = AttentionWrapperState(
        cell_state=LSTMStateTuple(
            c=array([...], dtype=float32),          # values elided
            h=array([...], dtype=float32)),         # values elided
        attention=array([...], dtype=float32),      # values elided
        time=3,
        attention_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9)
        attention_mechanism_depth=9,
        name="testLuongScaled")


if __name__ == "__main__":
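The tests above build each attention mechanism constructor up front and pass it to a shared harness; when extra keyword arguments are needed (for example `scale=True` for the scaled Luong form, or `normalize=True` for Bahdanau), they are bound ahead of time with `functools.partial`. A minimal, self-contained sketch of that construction pattern, using a hypothetical `FakeMechanism` class in place of the real `wrapper` module:

import functools


class FakeMechanism(object):
  """Stand-in for an attention mechanism constructor (hypothetical)."""

  def __init__(self, num_units, memory, scale=False):
    self.num_units = num_units
    self.memory = memory
    self.scale = scale


def run_harness(create_attention_mechanism):
  # The harness decides num_units/memory; the test only binds the extras.
  return create_attention_mechanism(num_units=3, memory=[[0.0] * 3])


# Plain form: pass the constructor directly, as testLuongNotNormalized does.
plain = run_harness(FakeMechanism)
# Scaled form: bind scale=True ahead of time, as testLuongScaled does.
scaled = run_harness(functools.partial(FakeMechanism, scale=True))
assert not plain.scale and scaled.scale
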
@ -176,19 +176,18 @@ class LuongAttention(_BaseAttentionMechanism):
  "Effective Approaches to Attention-based Neural Machine Translation."
  EMNLP 2015. https://arxiv.org/abs/1508.04025

  The second is the normalized form. This form is inspired by the
  normalization proposed for Bahdanau attention in

  Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck.
  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
  (Eq. 15).
  The second is the scaled form inspired partly by the normalized form of
  Bahdanau attention.

  To enable the second form, construct the object with parameter
  `normalize=True`.
  `scale=True`.
  """

  def __init__(self, num_units, memory, memory_sequence_length=None,
               normalize=False, attention_r_initializer=None,
  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               scale=False,
               name="LuongAttention"):
    """Construct the AttentionMechanism mechanism.

@ -199,31 +198,21 @@ class LuongAttention(_BaseAttentionMechanism):
      memory_sequence_length (optional): Sequence lengths for the batch entries
        in memory. If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      normalize: Python boolean. Whether to normalize the energy term.
      attention_r_initializer: Initial value of the post-normalization bias
        when normalizing. Default is `0`.
      scale: Python boolean. Whether to scale the energy term.
      name: Name to use when creating ops.
    """
    # For LuongAttention, we only transform the memory layer; thus
    # num_units **must** match the expected query depth.
    super(LuongAttention, self).__init__(
        query_layer=None,
        memory_layer=layers_core.Dense(num_units, name="memory_layer"),
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False),
        memory=memory,
        memory_sequence_length=memory_sequence_length,
        name=name)
    self._num_units = num_units
    self._normalize = normalize
    self._scale = scale
    self._name = name
    if normalize and attention_r_initializer is None:
      attention_r_initializer = 0
    if normalize:
      with ops.name_scope(name, "LuongAttention",
                          [memory, attention_r_initializer]):
        attention_r_initializer = ops.convert_to_tensor(
            attention_r_initializer, dtype=self.values.dtype,
            name="attention_r_initializer")
      self._attention_r_initializer = attention_r_initializer

  def __call__(self, query):
    """Score the query based on the keys and values.
@ -249,7 +238,7 @@ class LuongAttention(_BaseAttentionMechanism):
          % (query, depth, self.keys, key_units, key_units))
    dtype = query.dtype

    with ops.name_scope(None, "LuongAttentionCall", [query]):
    with variable_scope.variable_scope(None, "luong_attention", [query]):
      # Reshape from [batch_size, depth] to [batch_size, 1, depth]
      # for matmul.
      query = array_ops.expand_dims(query, 1)
@ -266,16 +255,11 @@ class LuongAttention(_BaseAttentionMechanism):
      score = math_ops.matmul(query, self.keys, transpose_b=True)
      score = array_ops.squeeze(score, [1])

      if self._normalize:
        # Scalar used in weight normalization
      if self._scale:
        # Scalar used in weight scaling
        g = variable_scope.get_variable(
            "attention_g", dtype=dtype,
            initializer=math.sqrt((1. / self._num_units)))
        # Scalar bias added to attention scores
        r = variable_scope.get_variable(
            "attention_r", dtype=dtype,
            initializer=self._attention_r_initializer)
        score = g * score + r
            "attention_g", dtype=dtype, initializer=1.)
        score = g * score

    return score

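For reference, the multiplicative score computed above reduces to a batched dot product between the query and the keys, with a single learned scalar `g` applied when `scale=True`. A NumPy sketch of the same arithmetic, illustrative only, with shapes matching the comments in the code above:

import numpy as np


def luong_score(query, keys, g=None):
  """query: [batch, depth]; keys: [batch, max_time, depth] -> [batch, max_time]."""
  # Batched matmul over the depth dimension, mirroring
  # matmul(expand_dims(query, 1), keys, transpose_b=True) followed by squeeze.
  score = np.einsum("bd,btd->bt", query, keys)
  if g is not None:
    # Scaled form: one learned scalar, initialized to 1 in the new code.
    score = g * score
  return score


batch, max_time, depth = 2, 5, 3
rng = np.random.RandomState(0)
q = rng.randn(batch, depth)
k = rng.randn(batch, max_time, depth)
print(luong_score(q, k).shape)         # (2, 5)
print(luong_score(q, k, g=1.0).shape)  # (2, 5); identical values at init
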
@ -290,18 +274,23 @@ class BahdanauAttention(_BaseAttentionMechanism):
  "Neural Machine Translation by Jointly Learning to Align and Translate."
  ICLR 2015. https://arxiv.org/abs/1409.0473

  The second is the normalized form, Raffel attention, as described in:
  The second is the normalized form. This form is inspired by the
  weight normalization article:

  Colin Raffel, Thang Luong, Peter J. Liu, Ron J. Weiss, and Douglas Eck.
  "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
  (Eq. 15).
  Tim Salimans, Diederik P. Kingma.
  "Weight Normalization: A Simple Reparameterization to Accelerate
   Training of Deep Neural Networks."
  https://arxiv.org/abs/1602.07868

  To enable the second form, construct the object with parameter
  `normalize=True`.
  """

  def __init__(self, num_units, memory, memory_sequence_length=None,
               normalize=False, attention_r_initializer=None,
  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               normalize=False,
               name="BahdanauAttention"):
    """Construct the Attention mechanism.

@ -313,28 +302,19 @@ class BahdanauAttention(_BaseAttentionMechanism):
        in memory. If provided, the memory tensor rows are masked with zeros
        for values past the respective sequence lengths.
      normalize: Python boolean. Whether to normalize the energy term.
      attention_r_initializer: Initial value of the post-normalization bias
        when normalizing. Default is `0`.
      name: Name to use when creating ops.
    """
    super(BahdanauAttention, self).__init__(
        query_layer=layers_core.Dense(num_units, name="query_layer"),
        memory_layer=layers_core.Dense(num_units, name="memory_layer"),
        query_layer=layers_core.Dense(
            num_units, name="query_layer", use_bias=False),
        memory_layer=layers_core.Dense(
            num_units, name="memory_layer", use_bias=False),
        memory=memory,
        memory_sequence_length=memory_sequence_length,
        name=name)
    self._num_units = num_units
    self._normalize = normalize
    self._name = name
    if normalize and attention_r_initializer is None:
      attention_r_initializer = 0
    if normalize:
      with ops.name_scope(name, "BahdanauAttention",
                          [memory, attention_r_initializer]):
        attention_r_initializer = ops.convert_to_tensor(
            attention_r_initializer, dtype=self.values.dtype,
            name="attention_r_initializer")
      self._attention_r_initializer = attention_r_initializer

  def __call__(self, query):
    """Score the query based on the keys and values.
@ -347,7 +327,7 @@ class BahdanauAttention(_BaseAttentionMechanism):
      score: Tensor of dtype matching `self.values` and shape
        `[batch_size, max_time]` (`max_time` is memory's `max_time`).
    """
    with ops.name_scope(None, "BahndahauAttentionCall", [query]):
    with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
      processed_query = self.query_layer(query) if self.query_layer else query
      dtype = processed_query.dtype
      # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
@ -363,15 +343,11 @@ class BahdanauAttention(_BaseAttentionMechanism):
        b = variable_scope.get_variable(
            "attention_b", [self._num_units], dtype=dtype,
            initializer=init_ops.zeros_initializer())
        # Scalar bias added to attention scores
        r = variable_scope.get_variable(
            "attention_r", dtype=dtype,
            initializer=self._attention_r_initializer)
        # normed_v = g * v / ||v||
        normed_v = g * v * math_ops.rsqrt(
            math_ops.reduce_sum(math_ops.square(v)))
        score = math_ops.reduce_sum(
            normed_v * math_ops.tanh(self.keys + processed_query + b), [2]) + r
            normed_v * math_ops.tanh(self.keys + processed_query + b), [2])
      else:
        score = math_ops.reduce_sum(
            v * math_ops.tanh(self.keys + processed_query), [2])
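The additive score above can be written out in a few lines of NumPy. When `normalize=True`, the direction of `v` is reparameterized with a scalar gain `g` and a bias `b` is added inside the tanh, following the weight-normalization form cited in the docstring. A sketch only, with names matching the TensorFlow code above:

import numpy as np


def bahdanau_score(processed_query, keys, v, g=None, b=None):
  """processed_query: [batch, num_units]; keys: [batch, max_time, num_units]."""
  q = processed_query[:, np.newaxis, :]  # broadcast over max_time
  if g is not None:
    # normed_v = g * v / ||v||, as in the normalize=True branch.
    normed_v = g * v / np.sqrt(np.sum(np.square(v)))
    return np.sum(normed_v * np.tanh(keys + q + b), axis=2)
  return np.sum(v * np.tanh(keys + q), axis=2)


batch, max_time, num_units = 2, 5, 4
rng = np.random.RandomState(0)
pq = rng.randn(batch, num_units)
keys = rng.randn(batch, max_time, num_units)
v = rng.randn(num_units)
print(bahdanau_score(pq, keys, v).shape)                                # (2, 5)
print(bahdanau_score(pq, keys, v, g=1.0, b=np.zeros(num_units)).shape)  # (2, 5)
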
@ -481,7 +457,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
    self._attention_mechanism = attention_mechanism
    self._attention_size = attention_size
    self._attention_layer = layers_core.Dense(
        attention_size, bias_initializer=None)
        attention_size, name="attention_layer", use_bias=False)
    self._cell_input_fn = cell_input_fn
    self._probability_fn = probability_fn
    self._output_attention = output_attention
@ -550,44 +526,44 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
    if scope is not None:
      raise NotImplementedError("scope not None is not supported")

    # Step 1: Calculate the true inputs to the cell based on the
    # previous attention value.
    cell_inputs = self._cell_input_fn(inputs, state.attention)
    cell_state = state.cell_state
    with variable_scope.variable_scope("attention"):
      # Step 1: Calculate the true inputs to the cell based on the
      # previous attention value.
      cell_inputs = self._cell_input_fn(inputs, state.attention)
      cell_state = state.cell_state

    cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
      cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

    score = self._attention_mechanism(cell_output)
    alignments = self._probability_fn(score)
      score = self._attention_mechanism(cell_output)
      alignments = self._probability_fn(score)

    # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
    alignments = array_ops.expand_dims(alignments, 1)
    # Context is the inner product of alignments and values along the
    # memory time dimension.
    # alignments shape is
    #   [batch_size, 1, memory_time]
    # attention_mechanism.values shape is
    #   [batch_size, memory_time, attention_mechanism.num_units]
    # the batched matmul is over memory_time, so the output shape is
    #   [batch_size, 1, attention_mechanism.num_units].
    # we then squeeze out the singleton dim.
    context = math_ops.matmul(alignments, self._attention_mechanism.values)
    context = array_ops.squeeze(context, [1])
      # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
      alignments = array_ops.expand_dims(alignments, 1)
      # Context is the inner product of alignments and values along the
      # memory time dimension.
      # alignments shape is
      #   [batch_size, 1, memory_time]
      # attention_mechanism.values shape is
      #   [batch_size, memory_time, attention_mechanism.num_units]
      # the batched matmul is over memory_time, so the output shape is
      #   [batch_size, 1, attention_mechanism.num_units].
      # we then squeeze out the singleton dim.
      context = math_ops.matmul(alignments, self._attention_mechanism.values)
      context = array_ops.squeeze(context, [1])

    attention = self._attention_layer(
        array_ops.concat([cell_output, context], 1))
      attention = self._attention_layer(
          array_ops.concat([cell_output, context], 1))

    if self._attention_history:
      attention_history = state.attention_history.write(
          state.time, attention)
    else:
      attention_history = ()
      if self._attention_history:
        attention_history = state.attention_history.write(state.time, attention)
      else:
        attention_history = ()

    next_state = AttentionWrapperState(
        time=state.time + 1,
        cell_state=next_cell_state,
        attention=attention,
        attention_history=attention_history)
      next_state = AttentionWrapperState(
          time=state.time + 1,
          cell_state=next_cell_state,
          attention=attention,
          attention_history=attention_history)

    if self._output_attention:
      return attention, next_state

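Putting the wrapper's step together: the score is turned into alignments by the probability function (softmax by default), the context is the alignment-weighted sum of the memory values, and the attention output is a bias-free linear layer over the concatenated cell output and context. A NumPy sketch of that data flow under those assumptions:

import numpy as np


def softmax(x, axis=-1):
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)


def attention_step(cell_output, score, values, w_attention):
  """cell_output: [batch, cell_units]; score: [batch, memory_time];
  values: [batch, memory_time, num_units];
  w_attention: [cell_units + num_units, attention_size]."""
  alignments = softmax(score)                          # probability_fn
  # context = alignments (as [batch, 1, memory_time]) @ values, then squeeze.
  context = np.einsum("bt,btu->bu", alignments, values)
  # Dense(attention_size, use_bias=False) over the concatenation.
  attention = np.concatenate([cell_output, context], axis=1) @ w_attention
  return attention, alignments


batch, memory_time, num_units, cell_units, attention_size = 2, 5, 4, 3, 6
rng = np.random.RandomState(0)
attention, alignments = attention_step(
    rng.randn(batch, cell_units),
    rng.randn(batch, memory_time),
    rng.randn(batch, memory_time, num_units),
    rng.randn(cell_units + num_units, attention_size))
print(attention.shape, alignments.shape)  # (2, 6) (2, 5)
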
@ -269,7 +269,12 @@ class SparseTensor(ItemHandler):
class Image(ItemHandler):
  """An ItemHandler that decodes a parsed Tensor as an image."""

  def __init__(self, image_key=None, format_key=None, shape=None, channels=3):
  def __init__(self,
               image_key=None,
               format_key=None,
               shape=None,
               channels=3,
               dtype=dtypes.uint8):
    """Initializes the image.

    Args:
@ -282,6 +287,11 @@ class Image(ItemHandler):
        accordingly. If left as None, no reshaping is done. A shape should
        be supplied only if all the stored images have the same shape.
      channels: the number of channels in the image.
      dtype: images will be decoded at this bit depth. Different formats
        support different bit depths.
          See tf.image.decode_png,
              tf.decode_raw,
              tf.image.decode_jpeg: only supports tf.uint8
    """
    if not image_key:
      image_key = 'image/encoded'
@ -293,6 +303,7 @@ class Image(ItemHandler):
    self._format_key = format_key
    self._shape = shape
    self._channels = channels
    self._dtype = dtype

  def tensors_to_item(self, keys_to_tensors):
    """See base class."""
@ -314,12 +325,17 @@ class Image(ItemHandler):
    """

    def decode_png():
      return image_ops.decode_png(image_buffer, self._channels)
      return image_ops.decode_png(
          image_buffer, self._channels, dtype=self._dtype)

    def decode_raw():
      return parsing_ops.decode_raw(image_buffer, dtypes.uint8)
      return parsing_ops.decode_raw(image_buffer, out_type=self._dtype)

    def decode_jpg():
      if self._dtype != dtypes.uint8:
        raise ValueError(
            'jpeg decoder can only be used to decode to tf.uint8 but %s was '
            'requested for a jpeg image.' % self._dtype)
      return image_ops.decode_jpeg(image_buffer, self._channels)

    # For RGBA images JPEG is not a valid decoder option.
@ -401,6 +417,7 @@ class TFExampleDecoder(data_decoder.DataDecoder):
    """
    example = parsing_ops.parse_single_example(serialized_example,
                                               self._keys_to_features)
    print(example.keys())

    # Reshape non-sparse elements just once:
    for k in self._keys_to_features:

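The format dispatch above now has to honor the requested `dtype`: PNG and raw decoding pass it through, while the JPEG decoder only supports `tf.uint8` and must reject anything else, as the test added below asserts. A small pure-Python sketch of that selection rule; the `choose_decoder` helper is hypothetical and only mirrors the logic, it is not part of the library:

def choose_decoder(image_format, dtype):
  """Mirror the decode_png / decode_raw / decode_jpg selection and its dtype rule."""
  if image_format in ('png', 'PNG'):
    return 'decode_png(dtype=%s)' % dtype
  if image_format in ('raw', 'RAW'):
    return 'decode_raw(out_type=%s)' % dtype
  if image_format in ('jpeg', 'jpg', 'JPEG', 'JPG'):
    if dtype != 'uint8':
      raise ValueError(
          'jpeg decoder can only be used to decode to tf.uint8 but %s was '
          'requested for a jpeg image.' % dtype)
    return 'decode_jpeg()'
  raise ValueError('unsupported format: %s' % image_format)


print(choose_decoder('png', 'uint16'))   # PNG may decode to 16 bits
print(choose_decoder('jpeg', 'uint8'))   # JPEG is uint8 only
# choose_decoder('jpeg', 'uint16') raises ValueError, matching the new test.
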
@ -224,6 +224,18 @@ class TFExampleDecoderTest(test.TestCase):

    self.assertAllClose(image, decoded_image, atol=0)

  def testDecodeExampleWithJpegEncodingAt16BitCausesError(self):
    image_shape = (2, 3, 3)
    unused_image, serialized_example = self.GenerateImage(
        image_format='jpeg', image_shape=image_shape)
    expected_regex = ('jpeg decoder can only be used to decode to tf.uint8 but '
                      '.* was requested for a jpeg image.')
    with self.assertRaisesRegexp(ValueError, expected_regex):
      unused_decoded_image = self.RunDecodeExample(
          serialized_example,
          tfexample_decoder.Image(dtype=dtypes.uint16),
          image_format='jpeg')

  def testDecodeExampleWithStringTensor(self):
    tensor_shape = (2, 3, 1)
    np_array = np.array([[['ab'], ['cd'], ['ef']],