Merge pull request #9582 from vrv/branch_154781831

Branch 154781831
Vijay Vasudevan 2017-05-01 19:57:43 -07:00 committed by GitHub
commit 0cfb68cb26
117 changed files with 4231 additions and 1050 deletions

View File

@ -406,7 +406,46 @@ class ReferenceUtil {
const PaddingConfig& padding,
const float pad);
// ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running
// f(x[i], y[i], ...) for each array element in the Array2Ds x, y, ....
//
// The given arrays must have the same size and element type, and the return
// type of f must be implicitly convertible to the arrays' element type.
//
// Example usage:
//
// Array2D<float> x, y, z = ...;
// std::unique_ptr<Array2D<float>> result = ReferenceUtil::ApplyElementwise2D(
// [](float a, float b, float c) { return a * b + c; }, x, y, z);
//
template <typename F, typename T1, typename... Ts>
static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
AssertSameSize2D(array1, arrays...);
auto result = MakeUnique<Array2D<T1>>(array1.n1(), array1.n2());
for (int64 i = 0; i < array1.n1(); ++i) {
for (int64 j = 0; j < array1.n2(); ++j) {
(*result)(i, j) = f(array1(i, j), arrays(i, j)...);
}
}
return result;
}
private:
template <typename T1, typename T2, typename... Ts>
static void AssertSameSize2D(const Array2D<T1>& array1,
const Array2D<T2>& array2,
const Array2D<Ts>&... arrays) {
static_assert(std::is_same<T1, T2>::value, "Args must be same type.");
CHECK_EQ(array1.n1(), array2.n1());
CHECK_EQ(array1.n2(), array2.n2());
AssertSameSize2D(array2, arrays...);
}
// Recursive base case for AssertSameSize2D.
template <typename Array1>
static void AssertSameSize2D(const Array1& array1) {}
TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
};
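For illustration only (not part of this change), a minimal sketch of the size and element-type constraints described in the comment above; the arrays x, y, z, and w here are hypothetical:

  // Matching sizes and element types: returns a 2x3 Array2D<float>.
  Array2D<float> x(2, 3), y(2, 3);
  auto sum = ReferenceUtil::ApplyElementwise2D(
      [](float a, float b) { return a + b; }, x, y);
  // Mismatched dimensions (2x3 vs. 3x2) would CHECK-fail at runtime in
  // AssertSameSize2D:
  //   Array2D<float> z(3, 2);
  //   ReferenceUtil::ApplyElementwise2D(
  //       [](float a, float b) { return a + b; }, x, z);
  // Mismatched element types (float vs. double) would fail to compile, via
  // the static_assert in AssertSameSize2D:
  //   Array2D<double> w(2, 3);
  //   ReferenceUtil::ApplyElementwise2D(
  //       [](float a, double b) { return a + b; }, x, w);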

View File

@ -52,9 +52,9 @@ class ReferenceUtilTest : public ::testing::Test {
TEST_F(ReferenceUtilTest, TransposeArray2D) {
auto result = ReferenceUtil::TransposeArray2D(*matrix_);
auto result_literal = LiteralUtil::CreateR2FromArray2D(*result);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}},
*result_literal, ErrorSpec(0.0001));
*actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MatmulArray2D) {
@ -62,32 +62,32 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) {
{7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f},
});
auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs);
auto result_literal = LiteralUtil::CreateR2FromArray2D(*result);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{58.f, 64.f}, {139.f, 154.f}},
*result_literal, ErrorSpec(0.0001));
*actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ReduceToColArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add);
auto result_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *result_literal,
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({6.f, 15.f}, *actual_literal,
ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ReduceToRowArray2D) {
auto add = [](float lhs, float rhs) { return lhs + rhs; };
auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add);
auto result_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *result_literal,
auto actual_literal = LiteralUtil::CreateR1<float>(*result);
LiteralTestUtil::ExpectR1Near<float>({5.f, 7.f, 9.f}, *actual_literal,
ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MapArray2D) {
auto identity = [](float value) { return log(exp(value)); };
auto result = ReferenceUtil::MapArray2D(*matrix_, identity);
auto result_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *result_literal,
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal,
ErrorSpec(0.0001));
}
@ -96,9 +96,9 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) {
return value + row + col;
};
auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index);
auto result_literal = LiteralUtil::CreateR2FromArray2D(*result);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result);
LiteralTestUtil::ExpectR2Near<float>({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}},
*result_literal, ErrorSpec(0.0001));
*actual_literal, ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, MapArray4D) {
@ -107,11 +107,11 @@ TEST_F(ReferenceUtilTest, MapArray4D) {
input->FillWithMultiples(1.0f);
auto multiply_by_two = [](float value) { return 2 * value; };
auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two);
auto result_literal = LiteralUtil::CreateR4FromArray4D(*result);
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.FillWithMultiples(2.0f);
LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal,
LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
ErrorSpec(0.0001));
}
@ -124,11 +124,11 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) {
return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width);
};
auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index);
auto result_literal = LiteralUtil::CreateR4FromArray4D(*result);
auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result);
Array4D<float> expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5);
expected.Fill(0.0f);
LiteralTestUtil::ExpectR4NearArray4D(expected, *result_literal,
LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal,
ErrorSpec(0.0001));
}
@ -302,5 +302,17 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) {
ErrorSpec(0.0001));
}
TEST_F(ReferenceUtilTest, ApplyElementwise2D) {
Array2D<float> a({{1, 2}, {3, 4}});
Array2D<float> b({{10, 20}, {30, 40}});
Array2D<float> c({{100, 200}, {300, 400}});
auto actual = ReferenceUtil::ApplyElementwise2D(
[](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c);
auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual);
LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}},
*actual_literal, ErrorSpec(0.0001));
}
} // namespace
} // namespace xla

View File

@ -859,7 +859,9 @@ cc_library(
":hlo_pass",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
],
)

View File

@ -410,7 +410,9 @@ HloInstruction::CreateSelectAndScatter(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReshape(
const Shape& shape, HloInstruction* operand) {
CHECK_EQ(ShapeUtil::ElementsIn(shape),
ShapeUtil::ElementsIn(operand->shape()));
ShapeUtil::ElementsIn(operand->shape()))
<< "shape: " << ShapeUtil::HumanString(shape)
<< " operand: " << ShapeUtil::HumanString(operand->shape());
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kReshape, shape));
instruction->AppendOperand(operand);
return instruction;
@ -1428,7 +1430,8 @@ string HloInstruction::ExtendedOpcodeStr() const {
return opc_name;
}
string HloInstruction::ToString(bool compact_operands) const {
string HloInstruction::ToString(bool compact_operands,
bool include_metadata) const {
string operands;
if (opcode() == HloOpcode::kConstant) {
// For constants, show the actual value in place of an empty operand list.
@ -1509,8 +1512,9 @@ string HloInstruction::ToString(bool compact_operands) const {
if (opcode() == HloOpcode::kGetTupleElement) {
StrAppend(&extra, ", index=", tuple_index());
}
if (!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty()) {
if (include_metadata &&
(!metadata_.op_type().empty() || !metadata_.op_name().empty() ||
!metadata_.source_file().empty())) {
StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
}

View File

@ -489,7 +489,10 @@ class HloInstruction {
string SignatureString() const;
// Returns a debugging string that represents this instruction.
string ToString(bool compact_operands = false) const;
string ToString(bool compact_operands = false,
bool include_metadata = true) const;
string ToStringNoMetadata() const { return ToString(false, false); }
// As ToString, but returns a shorter string.
string ToShortString() const;
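An illustrative usage sketch (not part of the diff; instr is a hypothetical HloInstruction*):

  // Default: metadata, when present, is appended as " # metadata=...".
  string with_metadata = instr->ToString();
  // Suppress the metadata suffix explicitly...
  string no_metadata = instr->ToString(/*compact_operands=*/false,
                                       /*include_metadata=*/false);
  // ...or via the new shorthand, which the reshape_mover VLOG statements in
  // this commit use.
  string also_no_metadata = instr->ToStringNoMetadata();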

View File

@ -13,17 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include <algorithm>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
namespace {
// Implementation note:
//
// The general idea behind this pass is that we're converting from this:
// %param.A = OldShape
// %param.B = OldShape
@ -44,6 +35,19 @@ namespace {
// only implicit scalar broadcast is on Pred, not on A or B. Since reshapes or
// transposes to a scalar should be cheap, we simply never move them.
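To make the rewrite concrete, here is an illustrative sketch (not part of this change) using the same test helpers exercised by the reshape_mover tests later in this commit; shapes and names are hypothetical:

  // Before the pass: add(reshape(param0), reshape(param1)), where each
  // reshape maps f32[1,8] to f32[8].
  HloComputation::Builder builder("sketch");
  auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {1, 8}), "param0"));
  auto param1 = builder.AddInstruction(HloInstruction::CreateParameter(
      1, ShapeUtil::MakeShape(F32, {1, 8}), "param1"));
  auto reshape0 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {8}), param0));
  auto reshape1 = builder.AddInstruction(
      HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {8}), param1));
  builder.AddInstruction(HloInstruction::CreateBinary(
      ShapeUtil::MakeShape(F32, {8}), HloOpcode::kAdd, reshape0, reshape1));
  auto module = MakeUnique<HloModule>("sketch");
  module->AddEntryComputation(builder.Build());
  // After ReshapeMover runs, the add is performed in f32[1,8] and a single
  // reshape back to f32[8] sits at the root: reshape(add(param0, param1)).
  EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie());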
#include "tensorflow/compiler/xla/service/reshape_mover.h"
#include <algorithm>
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
namespace {
// Finds the first non-scalar operand of an instruction that is a reshape or
// transpose and returns the operand if it is found or nullptr if not found.
HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
@ -51,6 +55,9 @@ HloInstruction* FirstNonScalarReshapeOperand(const HloInstruction* hlo) {
if (!ShapeUtil::IsScalar(operand->shape()) &&
(operand->opcode() == HloOpcode::kReshape ||
operand->opcode() == HloOpcode::kTranspose)) {
VLOG(5) << "Found first non-scalar reshape operand of "
<< hlo->ToStringNoMetadata() << ":\n\t"
<< operand->ToStringNoMetadata();
return operand;
}
}
@ -70,6 +77,9 @@ bool OperandCanTrivallyChangeShape(const HloInstruction* instruction,
// A constant can trivially reshape the literal it holds.
if (operand->opcode() == HloOpcode::kConstant &&
ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) {
VLOG(5) << "Constant had same dimensions as instruction:\n\toperand: "
<< operand->ToStringNoMetadata()
<< "\n\tinstruction: " << instruction->ToStringNoMetadata();
return true;
}
@ -116,119 +126,159 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes(
if (!first_reshape_operand) {
return false;
}
return (instruction->user_count() > 0 ||
instruction == instruction->parent()->root_instruction()) &&
instruction->IsElementwise() && !operands.empty() &&
// Check whether all operands:
// 1. are all reshapes or transposes that have the same input and
// output shapes as all other reshaped or transposed operands.
// or
// 2. can be any shape like kConstant, kRng, and scalars.
std::all_of(
operands.begin(), operands.end(),
[instruction,
first_reshape_operand](const HloInstruction* operand) {
return AreEquivalentReshapes(first_reshape_operand, operand) ||
OperandCanTrivallyChangeShape(instruction, operand);
});
VLOG(3) << "** Checking whether instruction is an elementwise operation of "
"equivalent reshapes/transposes: "
<< instruction->ToStringNoMetadata();
bool result =
(instruction->user_count() > 0 ||
instruction == instruction->parent()->root_instruction()) &&
instruction->IsElementwise() && !operands.empty() &&
// Check whether all operands:
// 0. Have the same dimensions as the output -- if not, it may be
// implicitly broadcast, which can confound the movement's
// correctness.
// 1. Are all reshapes or transposes that have the same input and
// output shapes as all other reshaped or transposed operands.
// or
// 2. Can be any shape like kConstant, kRng, and scalars.
std::all_of(
operands.begin(), operands.end(),
[instruction, first_reshape_operand](const HloInstruction* operand) {
if (!ShapeUtil::SameDimensions(operand->shape(),
instruction->shape())) {
VLOG(5) << "Operand shape differs from output shape; may be "
"implicitly broadcast, so preventing "
"movement\n\toperand: "
<< operand->ToStringNoMetadata() << "\n\tinstruction: "
<< instruction->ToStringNoMetadata();
return false;
}
if (AreEquivalentReshapes(first_reshape_operand, operand)) {
VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: "
<< first_reshape_operand->ToStringNoMetadata()
<< "\n\toperand: " << operand->ToStringNoMetadata();
return true;
}
if (OperandCanTrivallyChangeShape(instruction, operand)) {
VLOG(5) << "Operand can trivially change shape: "
<< operand->ToStringNoMetadata();
return true;
}
return false;
});
VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for "
<< instruction->ToStringNoMetadata() << ": " << result;
return result;
}
// Try to sink any reshape or transpose operands of `instruction` across it. We
// do so if `instruction` is elementwise and all operands are equivalent
// reshapes or transposes.
bool TrySinkReshapeOrTranspose(HloComputation* computation,
HloInstruction* instruction) {
if (IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
std::vector<HloInstruction*> operands = instruction->operands();
HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
CHECK(old_reshape != nullptr);
Shape new_elementwise_shape = old_reshape->operand(0)->shape();
for (size_t i = 0; i < operands.size(); ++i) {
// All scalar operands remain as-is, even if they're reshape or transpose,
// to simplify handling wrt special scalar broadcast rules for ops like
// Select. Scalar reshapes should be cheap anyways.
if (ShapeUtil::IsScalar(operands[i]->shape())) {
continue;
}
auto element_type = operands[i]->shape().element_type();
switch (operands[i]->opcode()) {
case HloOpcode::kConstant: {
if (old_reshape->opcode() == HloOpcode::kReshape) {
operands[i] = instruction->parent()->AddInstruction(
HloInstruction::CreateReshape(
ShapeUtil::ChangeElementType(new_elementwise_shape,
element_type),
operands[i]));
} else {
CHECK_EQ(old_reshape->opcode(), HloOpcode::kTranspose);
std::vector<int64> inverse_permutation =
InversePermutation(old_reshape->dimensions());
operands[i] = instruction->parent()->AddInstruction(
HloInstruction::CreateTranspose(
ShapeUtil::ChangeElementType(new_elementwise_shape,
element_type),
operands[i], inverse_permutation));
}
break;
}
case HloOpcode::kRng: {
CHECK_EQ(operands[i]->user_count(), 1);
StatusOr<bool> TrySinkReshapeOrTranspose(HloComputation* computation,
HloInstruction* instruction) {
if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) {
return false;
}
std::vector<HloInstruction*> operands = instruction->operands();
HloInstruction* old_reshape = FirstNonScalarReshapeOperand(instruction);
TF_RET_CHECK(old_reshape != nullptr);
Shape new_elementwise_shape = old_reshape->operand(0)->shape();
VLOG(3) << "** Trying to sink reshape or transpose: "
<< instruction->ToStringNoMetadata()
<< "\n\told reshape: " << old_reshape->ToStringNoMetadata()
<< "\n\tnew elementwise shape: "
<< ShapeUtil::HumanString(new_elementwise_shape);
for (size_t i = 0; i < operands.size(); ++i) {
// All scalar operands remain as-is, even if they're reshape or transpose,
// to simplify handling wrt special scalar broadcast rules for ops like
// Select. Scalar reshapes should be cheap anyways.
if (ShapeUtil::IsScalar(operands[i]->shape())) {
continue;
}
PrimitiveType element_type = operands[i]->shape().element_type();
switch (operands[i]->opcode()) {
case HloOpcode::kConstant: {
if (old_reshape->opcode() == HloOpcode::kReshape) {
VLOG(3) << "Creating reshape for kConstant operand " << i << ": "
<< operands[i]->ToStringNoMetadata();
operands[i] = instruction->parent()->AddInstruction(
operands[i]->CloneWithNewOperands(
HloInstruction::CreateReshape(
ShapeUtil::ChangeElementType(new_elementwise_shape,
element_type),
operands[i]->operands()));
break;
operands[i]));
} else {
TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose);
std::vector<int64> inverse_permutation =
InversePermutation(old_reshape->dimensions());
operands[i] = instruction->parent()->AddInstruction(
HloInstruction::CreateTranspose(
ShapeUtil::ChangeElementType(new_elementwise_shape,
element_type),
operands[i], inverse_permutation));
}
case HloOpcode::kReshape:
case HloOpcode::kTranspose:
operands[i] = operands[i]->mutable_operand(0);
break;
default:
LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
"transposes.";
}
}
if (HloOpcode::kFusion == instruction->opcode()) {
// Here we already know `instruction` is elementwise, and no operand is
// implicit broadcast as if it were the operands would not be equivalent
// reshapes, so all the fused instructions have the same dimensions.
for (const auto& fused_instruction : instruction->fused_instructions()) {
Shape* shape = fused_instruction->mutable_shape();
*shape->mutable_dimensions() = new_elementwise_shape.dimensions();
*shape->mutable_layout() = new_elementwise_shape.layout();
}
}
auto new_elementwise =
computation->AddInstruction(instruction->CloneWithNewOperands(
// `instruction` may change the element type, e.g., from
// operands[0] -> reshape -> convert (`instruction`)
// to
// operands[0] -> convert' -> reshape'
//
// In this case, convert' should have the same element type as
// `convert` and the same dimensions as operands[0].
ShapeUtil::ChangeElementType(new_elementwise_shape,
instruction->shape().element_type()),
operands));
std::unique_ptr<HloInstruction> new_reshape;
switch (old_reshape->opcode()) {
case HloOpcode::kReshape:
new_reshape = HloInstruction::CreateReshape(instruction->shape(),
new_elementwise);
break;
}
case HloOpcode::kRng: {
CHECK_EQ(operands[i]->user_count(), 1);
operands[i] = instruction->parent()->AddInstruction(
operands[i]->CloneWithNewOperands(
ShapeUtil::ChangeElementType(new_elementwise_shape,
element_type),
operands[i]->operands()));
break;
}
case HloOpcode::kReshape:
case HloOpcode::kTranspose:
new_reshape = HloInstruction::CreateTranspose(
instruction->shape(), new_elementwise, old_reshape->dimensions());
operands[i] = operands[i]->mutable_operand(0);
break;
default:
LOG(FATAL) << "Bad opcode";
LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or "
"transposes.";
}
TF_CHECK_OK(computation->ReplaceWithNewInstruction(instruction,
std::move(new_reshape)));
return true;
}
return false;
if (HloOpcode::kFusion == instruction->opcode()) {
// Here we already know `instruction` is elementwise, and no operand is
// implicit broadcast as if it were the operands would not be equivalent
// reshapes, so all the fused instructions have the same dimensions.
for (const auto& fused_instruction : instruction->fused_instructions()) {
Shape* shape = fused_instruction->mutable_shape();
*shape->mutable_dimensions() = new_elementwise_shape.dimensions();
*shape->mutable_layout() = new_elementwise_shape.layout();
}
}
HloInstruction* new_elementwise =
computation->AddInstruction(instruction->CloneWithNewOperands(
// `instruction` may change the element type, e.g., from
// operands[0] -> reshape -> convert (`instruction`)
// to
// operands[0] -> convert' -> reshape'
//
// In this case, convert' should have the same element type as
// `convert` and the same dimensions as operands[0].
ShapeUtil::ChangeElementType(new_elementwise_shape,
instruction->shape().element_type()),
operands));
std::unique_ptr<HloInstruction> new_reshape;
switch (old_reshape->opcode()) {
case HloOpcode::kReshape:
VLOG(3) << "Creating new reshape for new elementwise op: "
<< new_elementwise->ToStringNoMetadata();
new_reshape =
HloInstruction::CreateReshape(instruction->shape(), new_elementwise);
break;
case HloOpcode::kTranspose:
new_reshape = HloInstruction::CreateTranspose(
instruction->shape(), new_elementwise, old_reshape->dimensions());
break;
default:
LOG(FATAL) << "Bad opcode";
}
TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
instruction, std::move(new_reshape)));
return true;
}
} // namespace
@ -237,9 +287,9 @@ StatusOr<bool> ReshapeMover::Run(HloModule* module) {
bool changed = false;
for (const auto& comp : module->computations()) {
for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
if (TrySinkReshapeOrTranspose(comp.get(), instruction)) {
changed = true;
}
TF_ASSIGN_OR_RETURN(bool did_change,
TrySinkReshapeOrTranspose(comp.get(), instruction));
changed |= did_change;
}
}
return changed;

View File

@ -234,6 +234,58 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
EXPECT_EQ(select, computation->root_instruction());
}
// Tree looks like:
//
// param0 [1,128,1]
// |
// reshape [128,1] constant [128,1024]
// \ /
// multiply w/implicit broadcast [128,1024]
//
// The reshape mover would like to sink the reshape below the multiply.
//
// Previously we would attempt to insert a reshape of the constant to [1,128,1]
// (which is unsound, because it has a different number of elements) as
// preparation for sinking the reshape.
//
// To eliminate the unsoundness, we outlaw reshape sinking when one of the
// operands is implicitly broadcast in the elementwise consumer.
//
// TODO(b/37799338) However, it would be possible in this case to do a more
// in-depth analysis to get reshape movement to occur:
//
// 1. Note that the broadcast dimension (logical dimension 1) in the operands
// would map back to logical dimension 2 in the param0 node.
// 2. Match rank of the constant to the param0 node (by prepending a trivial 1
// dimension).
// 3. Reshape to [128,1024] at the root.
//
// But this is not currently done.
TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {1, 128, 1}), "param0"));
auto reshape = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {128, 1}), param0));
Array2D<float> a(128, 1024);
auto literal = LiteralUtil::CreateR2FromArray2D<float>(a);
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(std::move(literal)));
auto multiply = builder.AddInstruction(HloInstruction::CreateBinary(
constant->shape(), HloOpcode::kMultiply, constant, reshape));
auto module = MakeUnique<HloModule>(TestName());
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Multiply(op::Constant(), op::Reshape(param0)));
EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Multiply(op::Constant(), op::Reshape(param0)));
EXPECT_EQ(multiply, computation->root_instruction());
}
// Tree looks like this:
//
// add1

View File

@ -78,7 +78,6 @@ cc_library(
"//tensorflow/contrib/batching:batch_ops_kernels",
"//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
"//tensorflow/contrib/layers:bucketization_op_kernel",
"//tensorflow/contrib/layers:sparse_feature_cross_op_kernel",
"//tensorflow/contrib/nccl:nccl_kernels",
"//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
@ -93,7 +92,6 @@ cc_library(
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib",
"//tensorflow/contrib/layers:bucketization_op_op_lib",
"//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib",
"//tensorflow/contrib/nccl:nccl_ops_op_lib",
"//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",

View File

@ -95,7 +95,7 @@ class ElboRatioTest(test.TestCase):
n=n_samples,
form=entropy.ELBOForms.sample,
seed=42)
actual_kl = distributions.kl(q, p)
actual_kl = distributions.kl_divergence(q, p)
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.
@ -123,7 +123,7 @@ class ElboRatioTest(test.TestCase):
n=n_samples,
form=entropy.ELBOForms.analytic_entropy,
seed=42)
actual_kl = distributions.kl(q, p)
actual_kl = distributions.kl_divergence(q, p)
# Relative tolerance (rtol) chosen 2 times as large as minimum needed to
# pass.

View File

@ -22,11 +22,11 @@ import numpy as np
from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators
from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl
from tensorflow.contrib.distributions.python.ops import normal
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import normal
from tensorflow.python.platform import test
sge = stochastic_gradient_estimators

View File

@ -22,12 +22,12 @@ from tensorflow.contrib import distributions as distributions_lib
from tensorflow.contrib import layers
from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor
from tensorflow.contrib.bayesflow.python.ops import variational_inference_impl
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import normal
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.platform import test
st = stochastic_tensor
@ -68,7 +68,7 @@ class VariationalInferenceTest(test.TestCase):
def testDefaultVariationalAndPrior(self):
_, prior, variational, _, log_likelihood = mini_vae()
elbo = vi.elbo(log_likelihood)
expected_elbo = log_likelihood - kullback_leibler.kl(
expected_elbo = log_likelihood - kullback_leibler.kl_divergence(
variational.distribution, prior)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
@ -80,7 +80,7 @@ class VariationalInferenceTest(test.TestCase):
prior = normal.Normal(loc=3., scale=2.)
elbo = vi.elbo(
log_likelihood, variational_with_prior={variational: prior})
expected_elbo = log_likelihood - kullback_leibler.kl(
expected_elbo = log_likelihood - kullback_leibler.kl_divergence(
variational.distribution, prior)
sess.run(variables.global_variables_initializer())
self.assertAllEqual(*sess.run([expected_elbo, elbo]))
@ -121,7 +121,7 @@ class VariationalInferenceTest(test.TestCase):
# No analytic KL available between prior and variational distributions.
with self.assertRaisesRegexp(NotImplementedError, "No KL"):
distributions.kl(variational.distribution, prior)
distributions.kl_divergence(variational.distribution, prior)
elbo = vi.elbo(
variational_with_prior={variational: prior},

View File

@ -84,8 +84,9 @@ def elbo_ratio(log_p,
KL[q || p] = E[ Log[q(Z)] - Log[p(Z)] ]
```
Note that if `p` is a `Distribution`, then `distributions.kl(q, p)` may be
defined and available as an exact result.
Note that if `p` is a `Distribution`, then
`distributions.kl_divergence(q, p)` may be defined and available as an
exact result.
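Spelled out (an illustrative restatement, not part of this change), with the expectation taken under q and a Monte Carlo estimate over n draws from q (the sampled form referenced by the `n` argument):

  KL[q \,\|\, p] \;=\; \mathbb{E}_{Z \sim q}\left[\log q(Z) - \log p(Z)\right]
  \;\approx\; \frac{1}{n} \sum_{i=1}^{n} \left(\log q(z_i) - \log p(z_i)\right),
  \qquad z_i \sim q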
#### ELBO

View File

@ -28,10 +28,10 @@ from __future__ import print_function
from tensorflow.contrib.bayesflow.python.ops import stochastic_graph_impl as sg
from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl as st
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import tf_logging as logging
VI_PRIORS = "__vi_priors__"
@ -259,7 +259,7 @@ def _elbo(form, log_likelihood, log_joint, variational_with_prior,
kl = None
if log_joint is None and form in {ELBOForms.default, ELBOForms.analytic_kl}:
try:
kl = kullback_leibler.kl(q, p)
kl = kullback_leibler.kl_divergence(q, p)
logging.info("Using analytic KL between q:%s, p:%s", q, p)
except NotImplementedError as e:
if form == ELBOForms.analytic_kl:

View File

@ -47,9 +47,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc"
#"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/decode_audio_op.cc"
#"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/bucketization_kernel.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc"

View File

@ -70,7 +70,6 @@ GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir
GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(input_pipeline "${tensorflow_source_dir}/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(layers_bucketization "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc")
GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc")
GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc")
GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc")

View File

@ -623,8 +623,6 @@ GENERATE_PYTHON_OP_LIB("contrib_input_pipeline_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/input_pipeline/ops/gen_input_pipeline_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_image_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_image_ops.py)
GENERATE_PYTHON_OP_LIB("contrib_layers_bucketization_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_bucketization_op.py)
GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py)
GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops"

View File

@ -496,24 +496,6 @@ cuda_py_test(
],
)
cuda_py_test(
name = "normal_test",
size = "medium",
srcs = ["python/kernel_tests/normal_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)
cuda_py_test(
name = "poisson_test",
size = "small",
@ -614,18 +596,6 @@ cuda_py_test(
],
)
cuda_py_test(
name = "kullback_leibler_test",
size = "small",
srcs = ["python/kernel_tests/kullback_leibler_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "normal_conjugate_posteriors_test",
size = "small",
@ -751,22 +721,6 @@ cuda_py_test(
tags = ["no_pip"],
)
cuda_py_test(
name = "special_math_test",
size = "medium",
srcs = ["python/kernel_tests/special_math_test.py"],
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)
cuda_py_test(
name = "distribution_util_test",
size = "small",

View File

@ -71,7 +71,7 @@ See the @{$python/contrib.distributions} guide.
@@RelaxedOneHotCategorical
## Kullback-Leibler Divergence
@@kl
@@kl_divergence
@@RegisterKL
## Helper Functions
@ -106,7 +106,6 @@ from tensorflow.contrib.distributions.python.ops.exponential import *
from tensorflow.contrib.distributions.python.ops.gamma import *
from tensorflow.contrib.distributions.python.ops.geometric import *
from tensorflow.contrib.distributions.python.ops.inverse_gamma import *
from tensorflow.contrib.distributions.python.ops.kullback_leibler import *
from tensorflow.contrib.distributions.python.ops.laplace import *
from tensorflow.contrib.distributions.python.ops.logistic import *
from tensorflow.contrib.distributions.python.ops.mixture import *
@ -115,7 +114,6 @@ from tensorflow.contrib.distributions.python.ops.mvn_diag import *
from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import *
from tensorflow.contrib.distributions.python.ops.mvn_tril import *
from tensorflow.contrib.distributions.python.ops.negative_binomial import *
from tensorflow.contrib.distributions.python.ops.normal import *
from tensorflow.contrib.distributions.python.ops.normal_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.poisson import *
@ -129,6 +127,8 @@ from tensorflow.contrib.distributions.python.ops.uniform import *
from tensorflow.contrib.distributions.python.ops.wishart import *
from tensorflow.python.ops.distributions.conditional_distribution import *
from tensorflow.python.ops.distributions.distribution import *
from tensorflow.python.ops.distributions.kullback_leibler import *
from tensorflow.python.ops.distributions.normal import *
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member

View File

@ -21,11 +21,11 @@ from __future__ import print_function
import numpy as np
import scipy.special
from tensorflow.contrib.distributions.python.ops import bernoulli
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
@ -286,7 +286,7 @@ class BernoulliTest(test.TestCase):
a = bernoulli.Bernoulli(probs=a_p)
b = bernoulli.Bernoulli(probs=b_p)
kl = kullback_leibler.kl(a, b)
kl = kullback_leibler.kl_divergence(a, b)
kl_val = sess.run(kl)
kl_expected = (a_p * np.log(a_p / b_p) + (1. - a_p) * np.log(

View File

@ -20,13 +20,13 @@ import numpy as np
from scipy import special
from scipy import stats
from tensorflow.contrib.distributions.python.ops import beta as beta_lib
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
@ -349,13 +349,13 @@ class BetaTest(test.TestCase):
for dist1 in [d1, d1_sp]:
for dist2 in [d2, d2_sp]:
kl = kullback_leibler.kl(dist1, dist2)
kl = kullback_leibler.kl_divergence(dist1, dist2)
kl_val = sess.run(kl)
self.assertEqual(kl.get_shape(), shape)
self.assertAllClose(kl_val, kl_expected)
# Make sure KL(d1||d1) is 0
kl_same = sess.run(kullback_leibler.kl(d1, d1))
kl_same = sess.run(kullback_leibler.kl_divergence(d1, d1))
self.assertAllClose(kl_same, np.zeros_like(kl_expected))

View File

@ -41,6 +41,12 @@ class SoftplusBijectorTest(test.TestCase):
"""Inverse log det jacobian, before being reduced."""
return -np.log(1 - np.exp(-y))
def testHingeSoftnessZeroRaises(self):
with self.test_session():
bijector = Softplus(event_ndims=0, hinge_softness=0., validate_args=True)
with self.assertRaisesOpError("must be non-zero"):
bijector.forward([1., 1.]).eval()
def testBijectorForwardInverseEventDimsZero(self):
with self.test_session():
bijector = Softplus(event_ndims=0)

View File

@ -21,7 +21,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import categorical
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
@ -30,6 +29,7 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
@ -278,10 +278,10 @@ class CategoricalTest(test.TestCase):
a = categorical.Categorical(logits=a_logits)
b = categorical.Categorical(logits=b_logits)
kl = kullback_leibler.kl(a, b)
kl = kullback_leibler.kl_divergence(a, b)
kl_val = sess.run(kl)
# Make sure KL(a||a) is 0
kl_same = sess.run(kullback_leibler.kl(a, a))
kl_same = sess.run(kullback_leibler.kl_divergence(a, a))
prob_a = np_softmax(a_logits)
prob_b = np_softmax(b_logits)

View File

@ -22,12 +22,12 @@ from scipy import special
from scipy import stats
from tensorflow.contrib.distributions.python.ops import gamma as gamma_lib
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
@ -345,7 +345,7 @@ class GammaTest(test.TestCase):
g1 = gamma_lib.Gamma(concentration=alpha1, rate=beta1)
x = g0.sample(int(1e4), seed=0)
kl_sample = math_ops.reduce_mean(g0.log_prob(x) - g1.log_prob(x), 0)
kl_actual = kullback_leibler.kl(g0, g1)
kl_actual = kullback_leibler.kl_divergence(g0, g1)
# Execute graph.
[kl_sample_, kl_actual_] = sess.run([kl_sample, kl_actual])

View File

@ -185,19 +185,19 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
sample_kl_identity = math_ops.reduce_mean(
dist.log_prob(samps) - mvn_identity.log_prob(samps), 0)
analytical_kl_identity = ds.kl(dist, mvn_identity)
analytical_kl_identity = ds.kl_divergence(dist, mvn_identity)
sample_kl_scaled = math_ops.reduce_mean(
dist.log_prob(samps) - mvn_scaled.log_prob(samps), 0)
analytical_kl_scaled = ds.kl(dist, mvn_scaled)
analytical_kl_scaled = ds.kl_divergence(dist, mvn_scaled)
sample_kl_diag = math_ops.reduce_mean(
dist.log_prob(samps) - mvn_diag.log_prob(samps), 0)
analytical_kl_diag = ds.kl(dist, mvn_diag)
analytical_kl_diag = ds.kl_divergence(dist, mvn_diag)
sample_kl_chol = math_ops.reduce_mean(
dist.log_prob(samps) - mvn_chol.log_prob(samps), 0)
analytical_kl_chol = ds.kl(dist, mvn_chol)
analytical_kl_chol = ds.kl_divergence(dist, mvn_chol)
n = int(10e3)
baseline = ds.MultivariateNormalDiag(
@ -208,19 +208,21 @@ class MultivariateNormalDiagPlusLowRankTest(test.TestCase):
sample_kl_identity_diag_baseline = math_ops.reduce_mean(
baseline.log_prob(samps) - mvn_identity.log_prob(samps), 0)
analytical_kl_identity_diag_baseline = ds.kl(baseline, mvn_identity)
analytical_kl_identity_diag_baseline = ds.kl_divergence(
baseline, mvn_identity)
sample_kl_scaled_diag_baseline = math_ops.reduce_mean(
baseline.log_prob(samps) - mvn_scaled.log_prob(samps), 0)
analytical_kl_scaled_diag_baseline = ds.kl(baseline, mvn_scaled)
analytical_kl_scaled_diag_baseline = ds.kl_divergence(
baseline, mvn_scaled)
sample_kl_diag_diag_baseline = math_ops.reduce_mean(
baseline.log_prob(samps) - mvn_diag.log_prob(samps), 0)
analytical_kl_diag_diag_baseline = ds.kl(baseline, mvn_diag)
analytical_kl_diag_diag_baseline = ds.kl_divergence(baseline, mvn_diag)
sample_kl_chol_diag_baseline = math_ops.reduce_mean(
baseline.log_prob(samps) - mvn_chol.log_prob(samps), 0)
analytical_kl_chol_diag_baseline = ds.kl(baseline, mvn_chol)
analytical_kl_chol_diag_baseline = ds.kl_divergence(baseline, mvn_chol)
[
sample_mean_,

View File

@ -241,7 +241,7 @@ class MultivariateNormalTriLTest(test.TestCase):
scale_tril=np.linalg.cholesky(sigma_b),
validate_args=True)
kl = ds.kl(mvn_a, mvn_b)
kl = ds.kl_divergence(mvn_a, mvn_b)
self.assertEqual(batch_shape, kl.get_shape())
kl_v = kl.eval()
@ -263,7 +263,7 @@ class MultivariateNormalTriLTest(test.TestCase):
scale_tril=np.linalg.cholesky(sigma_b),
validate_args=True)
kl = ds.kl(mvn_a, mvn_b)
kl = ds.kl_divergence(mvn_a, mvn_b)
self.assertEqual(batch_shape, kl.get_shape())
kl_v = kl.eval()
@ -285,7 +285,7 @@ class MultivariateNormalTriLTest(test.TestCase):
validate_args=True)
# Should be zero since KL(p || p) = 0.
kl = ds.kl(mvn_a, mvn_a)
kl = ds.kl_divergence(mvn_a, mvn_a)
self.assertEqual(batch_shape, kl.get_shape())
kl_v = kl.eval()
@ -323,7 +323,7 @@ class MultivariateNormalTriLTest(test.TestCase):
sample_kl_chol = math_ops.reduce_mean(
dist.log_prob(samps) - mvn_chol.log_prob(samps), 0)
analytical_kl_chol = ds.kl(dist, mvn_chol)
analytical_kl_chol = ds.kl_divergence(dist, mvn_chol)
scale = dist.scale.to_dense()

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import onehot_categorical
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -27,6 +26,7 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.platform import test
@ -178,8 +178,8 @@ class OneHotCategoricalTest(test.TestCase):
kl_expected = np.sum(
prob_p * (np.log(prob_p) - np.log(prob_q)), axis=-1)
kl_actual = kullback_leibler.kl(p, q)
kl_same = kullback_leibler.kl(p, p)
kl_actual = kullback_leibler.kl_divergence(p, q)
kl_same = kullback_leibler.kl_divergence(p, p)
x = p.sample(int(2e4), seed=0)
x = math_ops.cast(x, dtype=dtypes.float32)
# Compute empirical KL(p||q).

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -29,6 +28,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -32,6 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -94,7 +94,9 @@ class Softplus(bijector.Bijector):
if validate_args:
nonzero_check = check_ops.assert_none_equal(
ops.convert_to_tensor(
0, dtype=self.hinge_softness.dtype), self.hinge_softness)
0, dtype=self.hinge_softness.dtype),
self.hinge_softness,
message="hinge_softness must be non-zero")
self._hinge_softness = control_flow_ops.with_dependencies(
[nonzero_check], self.hinge_softness)

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -28,6 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -32,6 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -22,7 +22,6 @@ import math
import numpy as np
from tensorflow.contrib.distributions.python.ops import special_math
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -33,6 +32,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
__all__ = [

View File

@ -20,8 +20,6 @@ from __future__ import print_function
from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import normal
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@ -29,6 +27,8 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -18,9 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops.normal import Normal # pylint: disable=line-too-long
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import normal
def normal_conjugates_known_scale_posterior(prior, scale, s, n):
@ -65,7 +64,7 @@ def normal_conjugates_known_scale_posterior(prior, scale, s, n):
TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
Normal object.
"""
if not isinstance(prior, Normal):
if not isinstance(prior, normal.Normal):
raise TypeError("Expected prior to be an instance of type Normal")
if s.dtype != prior.dtype:
@ -77,7 +76,7 @@ def normal_conjugates_known_scale_posterior(prior, scale, s, n):
scale0_2 = math_ops.square(prior.scale)
scale_2 = math_ops.square(scale)
scalep_2 = 1.0/(1/scale0_2 + n/scale_2)
return Normal(
return normal.Normal(
loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2,
scale=math_ops.sqrt(scalep_2))
@ -131,7 +130,7 @@ def normal_conjugates_known_scale_predictive(prior, scale, s, n):
TypeError: if dtype of `s` does not match `dtype`, or `prior` is not a
Normal object.
"""
if not isinstance(prior, Normal):
if not isinstance(prior, normal.Normal):
raise TypeError("Expected prior to be an instance of type Normal")
if s.dtype != prior.dtype:
@ -143,6 +142,6 @@ def normal_conjugates_known_scale_predictive(prior, scale, s, n):
scale0_2 = math_ops.square(prior.scale)
scale_2 = math_ops.square(scale)
scalep_2 = 1.0/(1/scale0_2 + n/scale_2)
return Normal(
return normal.Normal(
loc=(prior.loc/scale0_2 + s/scale_2) * scalep_2,
scale=math_ops.sqrt(scalep_2 + scale_2))

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -28,6 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -232,7 +232,7 @@ class QuantizedDistribution(distributions.Distribution):
graph_parents = self._dist._graph_parents # pylint: disable=protected-access
checks = []
if low is not None and high is not None:
if validate_args and low is not None and high is not None:
message = "low must be strictly less than high."
checks.append(
check_ops.assert_less(

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
@ -36,6 +37,9 @@ static int ParseFlags(int argc, char* argv[], string* in_graph) {
Flag("in_graph", in_graph, "input graph file name"),
};
CHECK(Flags::Parse(&argc, argv, flag_list));
// We need to call this to set up global state for TensorFlow.
port::InitMain(argv[0], &argc, &argv);
string usage = Flags::Usage(argv[0], flag_list);
CHECK(!in_graph->empty()) << "in_graph graph can't be empty.\n" << usage;

View File

@ -18,35 +18,6 @@ load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
tf_custom_op_library(
# TODO(sibyl-Mooth6ku,ptucker): Understand why 'python/ops/_' is needed and fix it.
name = "python/ops/_bucketization_op.so",
srcs = [
"ops/bucketization_op.cc",
],
deps = [
"//tensorflow/contrib/layers/kernels:bucketization_kernel",
],
)
tf_gen_op_libs(
op_lib_names = ["bucketization_op"],
)
tf_gen_op_wrapper_py(
name = "bucketization_op",
deps = [":bucketization_op_op_lib"],
)
tf_kernel_library(
name = "bucketization_op_kernel",
deps = [
"//tensorflow/contrib/layers/kernels:bucketization_kernel",
"//tensorflow/core:framework",
],
alwayslink = 1,
)
tf_custom_op_library(
# TODO(sibyl-Mooth6ku,ptucker): Understand why 'python/ops/_' is needed and fix it.
name = "python/ops/_sparse_feature_cross_op.so",
@ -97,18 +68,14 @@ tf_custom_op_py_library(
"python/ops/sparse_ops.py",
],
dso = [
":python/ops/_bucketization_op.so",
":python/ops/_sparse_feature_cross_op.so",
],
kernels = [
":bucketization_op_kernel",
":sparse_feature_cross_op_kernel",
":bucketization_op_op_lib",
":sparse_feature_cross_op_op_lib",
],
srcs_version = "PY2AND3",
deps = [
":bucketization_op",
":sparse_feature_cross_op",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/lookup:lookup_py",
@ -315,22 +282,6 @@ py_test(
],
)
py_test(
name = "bucketization_op_test",
size = "small",
srcs = ["python/kernel_tests/bucketization_op_test.py"],
srcs_version = "PY2AND3",
deps = [
":layers_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//third_party/py/numpy",
],
)
py_test(
name = "sparse_feature_cross_op_test",
size = "medium",

View File

@ -18,17 +18,14 @@ these arguments.
…,
weight_init=<DEFAULT>,
bias_init=<DEFAULT>,
weight_collections=(tf.GraphKeys.WEIGHTS,),
bias_collections=(tf.GraphKeys.BIASES,),
output_collections=(tf.GraphKeys.ACTIVATIONS,),
weight_regularizer=None,
bias_regularizer=None,
name=None) : Tensor`
`x` is the input tensor.
Weights, biases, and activations (i.e., outputs) are, by default, added to the specified collections. Weights and biases are also added to
`tf.GraphKeys.GLOBAL_VARIABLES` and `tf.GraphKeys.TRAINABLE_VARIABLES`.
Weights and biases are added to `tf.GraphKeys.GLOBAL_VARIABLES` and
`tf.GraphKeys.TRAINABLE_VARIABLES`.
## optimizers.py

View File

@ -7,17 +7,6 @@ exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
cc_library(
name = "bucketization_kernel",
srcs = ["bucketization_kernel.cc"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
"@protobuf//:protobuf_headers",
],
alwayslink = 1,
)
cc_library(
name = "sparse_feature_cross_kernel",
srcs = ["sparse_feature_cross_kernel.cc"],

View File

@ -1,46 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
namespace tensorflow {
REGISTER_OP("Bucketize")
.Input("input: T")
.Output("output: int32")
.Attr("T: {int32, int64, float, double}")
.Attr("boundaries: list(float)")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Bucketizes 'input' based on 'boundaries'.
For example, if the inputs are
boundaries = [0, 10, 100]
input = [[-5, 10000]
[150, 10]
[5, 100]]
then the output will be
output = [[0, 3]
[3, 2]
[1, 3]]
input: Any shape of Tensor contains with int or float type.
boundaries: A sorted list of floats gives the boundary of the buckets.
output: Same shape with 'input', each value of input replaced with bucket index.
)doc");
} // namespace tensorflow

View File

@ -17,13 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.layers.ops import gen_bucketization_op
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
_bucketization_op = loader.load_op_library(
resource_loader.get_path_to_datafile("_bucketization_op.so"))
from tensorflow.python.ops import math_ops
def bucketize(input_tensor, boundaries, name=None):
@ -43,10 +37,5 @@ def bucketize(input_tensor, boundaries, name=None):
Raises:
TypeError: If boundaries is not a list.
"""
if not isinstance(boundaries, list):
raise TypeError("boundaries must be a list")
return gen_bucketization_op.bucketize(input_tensor, boundaries, name=name)
ops.NotDifferentiable("Bucketize")
return math_ops._bucketize( # pylint: disable=protected-access
input_tensor, boundaries=boundaries, name=name)

View File

@ -1,247 +0,0 @@
# TF Learn
TF Learn is a simplified interface for TensorFlow, to get people started on predictive analytics and data mining. The library covers a variety of needs: from linear models to *Deep Learning* applications like text and image understanding.
### Why *TensorFlow*?
* TensorFlow provides a good backbone for building different shapes of machine learning applications.
* It will continue to evolve both in the distributed direction and as general pipelining machinery.
### Why *TensorFlow Learn*?
- To smooth the transition from the [scikit-learn](http://scikit-learn.org/stable/) world of one-liner machine learning into the more open world of building different shapes of ML models. You can start by using [fit](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Estimator#fit)/[predict](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Estimator#predict) and slide into TensorFlow APIs as you are getting comfortable.
- To provide a set of reference models that will be easy to integrate with existing code.
## Installation
[Install TensorFlow](https://www.tensorflow.org/install/), and then simply import `learn` via `from tensorflow.contrib import learn`, or use `tf.contrib.learn` directly.
Optionally you can install [scikit-learn](http://scikit-learn.org/stable/) and [pandas](http://pandas.pydata.org/) for additional functionality.
### Tutorials
- [TF Learn Quickstart](https://www.tensorflow.org/get_started/tflearn). Build,
train, and evaluate a neural network with just a few lines of code.
- [Input Functions](https://www.tensorflow.org/get_started/input_fn). Learn how
to create input functions to feed data into your models.
- [Linear Model](https://www.tensorflow.org/tutorials/wide). Learn the basics
of building linear models.
- [Wide and Deep Learning](https://www.tensorflow.org/tutorials/wide_and_deep).
Jointly train a linear model and a deep neural network.
- [Logging and Monitoring](https://www.tensorflow.org/get_started/monitors).
Use the Monitor API to audit training of a neural network.
- [Custom Estimators](https://www.tensorflow.org/extend/estimators). Learn
how to create a custom estimator.
- More coming soon.
### Community
- Twitter [#tensorflow](https://twitter.com/search?q=tensorflow&src=typd).
- StackOverflow with [tensorflow tag](http://stackoverflow.com/questions/tagged/tensorflow) for questions and struggles.
- GitHub [issues](https://github.com/tensorflow/tensorflow/issues) for technical discussions and feature requests.
### Existing Estimator Implementations
- [`LinearClassifier`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/linear.py)
([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/LinearClassifier))
- [`LinearRegressor`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/linear.py)
([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/LinearRegressor))
- [`DNNClassifier`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/dnn.py)
([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNClassifier))
- [`DNNRegressor`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/dnn.py)
([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNRegressor))
- [`DNNLinearCombinedClassifier`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py)
([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNLinearCombinedClassifier))
- [`DNNLinearCombinedRegressor`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py)
([docs](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/DNNLinearCombinedRegressor))
- [`SVM`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/svm.py)
([docs](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/g3doc/svm.md))
- [`GMM`](https://www.tensorflow.org/code/tensorflow/contrib/factorization/python/ops/gmm.py)
([docs](https://www.tensorflow.org/code/tensorflow/contrib/factorization/g3doc/gmm.md))
- [`KMeansClustering`](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/kmeans.py)
([docs](https://www.tensorflow.org/code/tensorflow/contrib/factorization/g3doc/kmeans.md))
### Usage Examples
Below are a few simple examples of the API. For more examples, please see [examples](https://www.tensorflow.org/code/tensorflow/examples/learn).
General tips:
- It's useful to rescale a dataset to 0 mean and unit standard deviation before passing it to an [`Estimator`](https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Estimator). [Stochastic Gradient Descent](https://en.wikipedia.org/wiki/Stochastic_gradient_descent) doesn't always do the right thing when variables are at very different scales.
- Categorical variables should be encoded before the input is passed to the estimator.
## Linear Classifier
Simple linear classification:
```python
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.LinearClassifier(n_classes=3, feature_columns=feature_columns)
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)
```
## Linear Regressor
Simple linear regression:
```python
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics, preprocessing
boston = datasets.load_boston()
x = preprocessing.StandardScaler().fit_transform(boston.data)
feature_columns = learn.infer_real_valued_columns_from_input(x)
regressor = learn.LinearRegressor(feature_columns=feature_columns)
regressor.fit(x, boston.target, steps=200, batch_size=32)
boston_predictions = list(regressor.predict(x, as_iterable=True))
score = metrics.mean_squared_error(boston_predictions, boston.target)
print("MSE: %f" % score)
```
## Deep Neural Network
Example of a 3-layer network with 10, 20, and 10 hidden units respectively:
```python
import tensorflow.contrib.learn.python.learn as learn
from sklearn import datasets, metrics
iris = datasets.load_iris()
feature_columns = learn.infer_real_valued_columns_from_input(iris.data)
classifier = learn.DNNClassifier(hidden_units=[10, 20, 10], n_classes=3, feature_columns=feature_columns)
classifier.fit(iris.data, iris.target, steps=200, batch_size=32)
iris_predictions = list(classifier.predict(iris.data, as_iterable=True))
score = metrics.accuracy_score(iris.target, iris_predictions)
print("Accuracy: %f" % score)
```
## Custom model
Example of how to pass a custom model to the Estimator:
```python
from sklearn import datasets
from sklearn import metrics
import tensorflow as tf
import tensorflow.contrib.layers.python.layers as layers
import tensorflow.contrib.learn.python.learn as learn
iris = datasets.load_iris()
def my_model(features, labels):
"""DNN with three hidden layers."""
# Convert the labels to a one-hot tensor of shape (length of features, 3) and
# with an on-value of 1 for each one-hot vector of length 3.
labels = tf.one_hot(labels, 3, 1, 0)
# Create three fully connected layers respectively of size 10, 20, and 10.
features = layers.stack(features, layers.fully_connected, [10, 20, 10])
# Create two tensors respectively for prediction and loss.
prediction, loss = (
tf.contrib.learn.models.logistic_regression(features, labels)
)
# Create a tensor for training op.
train_op = tf.contrib.layers.optimize_loss(
loss, tf.contrib.framework.get_global_step(), optimizer='Adagrad',
learning_rate=0.1)
return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op
classifier = learn.Estimator(model_fn=my_model)
classifier.fit(iris.data, iris.target, steps=1000)
y_predicted = [
p['class'] for p in classifier.predict(iris.data, as_iterable=True)]
score = metrics.accuracy_score(iris.target, y_predicted)
print('Accuracy: {0:f}'.format(score))
```
## Saving / Restoring models
Each estimator supports a `model_dir` argument, which takes a folder path where all model information will be saved:
```python
classifier = learn.DNNClassifier(..., model_dir="/tmp/my_model")
```
If you run multiple `fit` operations on the same `Estimator`, training will resume where the last operation left off, e.g.:
<pre><strong>classifier = learn.DNNClassifier(..., model_dir="/tmp/my_model")
classifier.fit(..., steps=300)</strong>
INFO:tensorflow:Create CheckpointSaverHook
INFO:tensorflow:loss = 2.40115, step = 1
INFO:tensorflow:Saving checkpoints for 1 into /tmp/leftoff/model.ckpt.
INFO:tensorflow:loss = 0.338706, step = 101
INFO:tensorflow:loss = 0.159414, step = 201
INFO:tensorflow:Saving checkpoints for 300 into /tmp/leftoff/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0953846.
<strong>classifier.fit(..., steps=300)</strong>
INFO:tensorflow:Create CheckpointSaverHook
INFO:tensorflow:loss = 0.113173, step = 301
INFO:tensorflow:Saving checkpoints for 301 into /tmp/leftoff/model.ckpt.
INFO:tensorflow:loss = 0.175782, step = 401
INFO:tensorflow:loss = 0.119735, step = 501
INFO:tensorflow:Saving checkpoints for 600 into /tmp/leftoff/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0518137.</pre>
To restore checkpoints to a new `Estimator`, just pass it the same `model_dir` argument, e.g.:
<pre><strong>classifier = learn.DNNClassifier(..., model_dir="/tmp/my_model")
classifier.fit(..., steps=300)</strong>
INFO:tensorflow:Create CheckpointSaverHook
INFO:tensorflow:loss = 1.16335, step = 1
INFO:tensorflow:Saving checkpoints for 1 into /tmp/leftoff/model.ckpt.
INFO:tensorflow:loss = 0.176995, step = 101
INFO:tensorflow:loss = 0.184573, step = 201
INFO:tensorflow:Saving checkpoints for 300 into /tmp/leftoff/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0512496.
<strong>classifier2 = learn.DNNClassifier(..., model_dir="/tmp/my_model")
classifier2.fit(..., steps=300)</strong>
INFO:tensorflow:Create CheckpointSaverHook
INFO:tensorflow:loss = 0.0543797, step = 301
INFO:tensorflow:Saving checkpoints for 301 into /tmp/leftoff/model.ckpt.
INFO:tensorflow:loss = 0.101036, step = 401
INFO:tensorflow:loss = 0.137956, step = 501
INFO:tensorflow:Saving checkpoints for 600 into /tmp/leftoff/model.ckpt.
INFO:tensorflow:Loss for final step: 0.0162506.</pre>
## Summaries
If you supply a `model_dir` argument to your `Estimator`s, TensorFlow will write summaries for `loss` and histograms for variables in this directory. (You can also add custom summaries in your custom model function by calling [Summary](https://www.tensorflow.org/api_guides/python/summary) operations.)
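For example, the custom model function shown earlier could record extra summaries before returning. This is a minimal sketch rather than a complete listing: `prediction`, `loss`, and `train_op` are assumed to be built exactly as in that example, and the summary names are illustrative.
```python
import tensorflow as tf

def my_model(features, labels):
  # ... build `prediction`, `loss`, and `train_op` as in the custom model
  # example above ...
  # Summaries created here are written to the Estimator's model_dir alongside
  # the default `loss` summary.
  tf.summary.scalar('training_loss', loss)
  tf.summary.histogram('predictions', prediction)
  return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op
```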
To view the summaries in TensorBoard, run the following command, where `logdir` is the `model_dir` for your `Estimator`:
```shell
tensorboard --logdir=/tmp/tf_examples/my_model_1
```
and then load the reported URL.
**Graph visualization**
![Text classification RNN Graph](https://raw.githubusercontent.com/tensorflow/skflow/master/g3doc/images/text_classification_rnn_graph.png)
**Loss visualization**
![Text classification RNN Loss](https://raw.githubusercontent.com/tensorflow/skflow/master/g3doc/images/text_classification_rnn_loss.png)
## More examples
See the [examples folder](https://www.tensorflow.org/code/tensorflow/examples/learn) for:
- An easy way to handle [categorical variables](https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification.py) (words are just an example of a categorical variable)
- Text Classification: see examples for [RNN](https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification_character_rnn.py) and [CNN](https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification_character_cnn.py) on characters
- [Digit recognition using a CNN](https://www.tensorflow.org/code/tensorflow/examples/learn/mnist.py)
- And much more!

View File

@ -1287,9 +1287,6 @@ class Estimator(BaseEstimator):
else:
saver_for_restore = saver.Saver(sharded=True)
with tf_session.Session('') as session:
variables.initialize_local_variables()
data_flow_ops.tables_initializer()
resources.initialize_resources(resources.shared_resources())
saver_for_restore.restore(session, checkpoint_path)
init_op = control_flow_ops.group(
variables.local_variables_initializer(),

View File

@ -37,6 +37,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
@ -613,7 +614,9 @@ def _create_model_fn_ops(features,
if (mode != model_fn.ModeKeys.INFER) and (labels is not None):
weight_tensor = _weight_tensor(features, weight_column_name)
loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor)
summary.scalar(
# Uses the deprecated API to set the tag explicitly.
# Without it, training and eval losses will show up in different graphs.
logging_ops.scalar_summary(
_summary_key(head_name, mkey.LOSS), weighted_average_loss)
if mode == model_fn.ModeKeys.TRAIN:

View File

@ -123,7 +123,7 @@ class PoissonHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=logits)
self._assert_output_alternatives(model_fn_ops)
_assert_summary_tags(self, ["regression_head/loss"])
_assert_summary_tags(self, ["loss"])
_assert_no_variables(self)
loss = self._log_poisson_loss(logits, labels)
_assert_metrics(self, loss, {"loss": loss}, model_fn_ops)
@ -149,7 +149,7 @@ class RegressionHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
_assert_summary_tags(self, ["regression_head/loss"])
_assert_summary_tags(self, ["loss"])
_assert_no_variables(self)
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
@ -179,7 +179,7 @@ class RegressionHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["regression_head/loss"])
_assert_summary_tags(self, ["loss"])
_assert_metrics(self, 2. / 3, {"loss": 2. / 3}, model_fn_ops)
def testRegressionWithLogitsAndLogitsInput(self):
@ -207,7 +207,7 @@ class RegressionHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["regression_head/loss"])
_assert_summary_tags(self, ["loss"])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionWithLabelName(self):
@ -222,7 +222,7 @@ class RegressionHeadTest(test.TestCase):
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["regression_head/loss"])
_assert_summary_tags(self, ["loss"])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionWithWeights(self):
@ -237,7 +237,7 @@ class RegressionHeadTest(test.TestCase):
logits=((1.,), (1.,), (3.,)))
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["regression_head/loss"])
_assert_summary_tags(self, ["loss"])
_assert_metrics(self, 2. / len(weights), {"loss": 2. / np.sum(weights)},
model_fn_ops)
@ -260,7 +260,7 @@ class RegressionHeadTest(test.TestCase):
expected_trainable=("regression_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self, [
"regression_head/loss",
"loss",
"regression_head/centered_bias/bias_0"
])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
@ -331,7 +331,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -348,7 +348,7 @@ class MultiLabelHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn, logits=logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 1.00320443
_assert_metrics(self, expected_loss, {
"accuracy": 0.,
@ -388,7 +388,7 @@ class MultiLabelHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .69314718
_assert_metrics(self, expected_loss, {
"accuracy": 2. / 3,
@ -433,7 +433,7 @@ class MultiLabelHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -452,7 +452,7 @@ class MultiLabelHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 1.377779
expected_eval_metrics = {
"accuracy": 1. / 3,
@ -520,7 +520,7 @@ class MultiLabelHeadTest(test.TestCase):
head_lib.no_op_train_fn, logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -540,7 +540,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
_assert_metrics(self, .089985214,
self._expected_eval_metrics(.89985214), model_fn_ops)
@ -559,7 +559,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
_assert_metrics(self, .089985214,
self._expected_eval_metrics(.89985214), model_fn_ops)
@ -579,7 +579,7 @@ class MultiLabelHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .089985214
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -604,7 +604,7 @@ class MultiLabelHeadTest(test.TestCase):
expected_trainable=("multi_label_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self, (
"multi_label_head/loss",
"loss",
"multi_label_head/centered_bias/bias_0",
"multi_label_head/centered_bias/bias_1",
"multi_label_head/centered_bias/bias_2"
@ -629,7 +629,7 @@ class MultiLabelHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=self._logits)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .89985204
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -695,7 +695,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_logistic_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -723,7 +723,7 @@ class BinaryClassificationHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["binary_logistic_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .69314718
label_mean = np.mean(self._labels)
_assert_metrics(self, expected_loss, {
@ -759,7 +759,7 @@ class BinaryClassificationHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_logistic_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -838,7 +838,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_logistic_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -859,7 +859,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_logistic_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_total_loss = .31326166
_assert_metrics(
self,
@ -892,7 +892,7 @@ class BinaryClassificationHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_logistic_head/loss"])
_assert_summary_tags(self, ["loss"])
# logloss: z:label, x:logit
# z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
# expected_loss is (total_weighted_loss)/1 since there is 1 nonzero
@ -932,7 +932,7 @@ class BinaryClassificationHeadTest(test.TestCase):
expected_trainable=("binary_logistic_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self, [
"binary_logistic_head/loss",
"loss",
"binary_logistic_head/centered_bias/bias_0"
])
expected_loss = .81326175
@ -983,7 +983,7 @@ class MultiClassHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_class_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 1.5514447
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -1022,7 +1022,7 @@ class MultiClassHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["multi_class_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 1.0986123
_assert_metrics(self, expected_loss, {
"accuracy": 0.,
@ -1073,7 +1073,7 @@ class MultiClassHeadTest(test.TestCase):
expected_trainable=("multi_class_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self,
["multi_class_head/loss",
["loss",
"multi_class_head/centered_bias/bias_0",
"multi_class_head/centered_bias/bias_1",
"multi_class_head/centered_bias/bias_2"])
@ -1091,7 +1091,7 @@ class MultiClassHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_class_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 1.5514447
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -1110,7 +1110,7 @@ class MultiClassHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_class_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 3.1698461
expected_eval_metrics = {
"accuracy": 0.,
@ -1149,7 +1149,7 @@ class MultiClassHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_class_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 1.5514447
_assert_metrics(self, expected_loss * weight,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -1173,7 +1173,7 @@ class MultiClassHeadTest(test.TestCase):
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_class_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 1.5514447 * weight
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -1280,7 +1280,7 @@ class MultiClassHeadTest(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_class_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 1.5514447
expected_eval_metrics = {
"accuracy": 0.,
@ -1306,7 +1306,7 @@ class MultiClassHeadTest(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_class_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 0.5514447
expected_eval_metrics = {
"accuracy": 1.,
@ -1345,7 +1345,7 @@ class BinarySvmHeadTest(test.TestCase):
logits=self._predictions)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_svm_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
@ -1375,7 +1375,7 @@ class BinarySvmHeadTest(test.TestCase):
_assert_variables(
self, expected_global=w, expected_model=w, expected_trainable=w)
variables.global_variables_initializer().run()
_assert_summary_tags(self, ["binary_svm_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = 1.
_assert_metrics(self, expected_loss, {
"accuracy": .5,
@ -1407,7 +1407,7 @@ class BinarySvmHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_svm_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
@ -1426,7 +1426,7 @@ class BinarySvmHeadTest(test.TestCase):
logits=self._predictions)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_svm_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,
@ -1446,7 +1446,7 @@ class BinarySvmHeadTest(test.TestCase):
logits=self._predictions)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_svm_head/loss"])
_assert_summary_tags(self, ["loss"])
expected_weighted_losses = np.multiply(weights, self._expected_losses)
_assert_metrics(self, np.mean(expected_weighted_losses), {
"accuracy": 1.,
@ -1473,7 +1473,7 @@ class BinarySvmHeadTest(test.TestCase):
expected_trainable=("binary_svm_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(self, [
"binary_svm_head/loss",
"loss",
"binary_svm_head/centered_bias/bias_0"
])
expected_loss = np.average(self._expected_losses)

View File

@ -45,6 +45,7 @@ See @{$python/contrib.rnn} guide.
@@UGRNNCell
@@IntersectionRNNCell
@@PhasedLSTMCell
@@HighwayWrapper
### RNNCell wrappers
@@AttentionCellWrapper

View File

@ -882,6 +882,30 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].c, expected_state_c)
self.assertAllClose(res[1].h, expected_state_h)
def testHighwayWrapper(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
"base_cell", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3])
base_cell = core_rnn_cell_impl.GRUCell(3)
g, m_new = base_cell(x, m)
with variable_scope.variable_scope(
"hw_cell", initializer=init_ops.constant_initializer(0.5)):
hw_cell = rnn_cell.HighwayWrapper(
core_rnn_cell_impl.GRUCell(3), carry_bias_init=-100.0)
g_res, m_new_res = hw_cell(x, m)
sess.run([variables.global_variables_initializer()])
res = sess.run([g, g_res, m_new, m_new_res], {
x: np.array([[1., 1., 1.]]),
m: np.array([[0.1, 0.1, 0.1]])
})
# As carry_bias_init is very negative, the carry gate saturates near zero and
# the transform gate near one, so the wrapper's output matches the base
# cell's output.
self.assertAllClose(res[1], res[0])
# States are left untouched
self.assertAllClose(res[2], res[3])
class LayerNormBasicLSTMCellTest(test.TestCase):

View File

@ -1157,6 +1157,89 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
return new_attns, new_attn_states
class HighwayWrapper(core_rnn_cell.RNNCell):
"""RNNCell wrapper that adds highway connection on cell input and output.
Based on:
R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
arXiv preprint arXiv:1505.00387, 2015.
https://arxiv.org/abs/1505.00387
"""
def __init__(self, cell,
couple_carry_transform_gates=True,
carry_bias_init=1.0):
"""Constructs a `HighwayWrapper` for `cell`.
Args:
cell: An instance of `RNNCell`.
couple_carry_transform_gates: boolean, whether the carry and transform
gates should be coupled.
carry_bias_init: float, carry gates bias initialization.
"""
self._cell = cell
self._couple_carry_transform_gates = couple_carry_transform_gates
self._carry_bias_init = carry_bias_init
@property
def state_size(self):
return self._cell.state_size
@property
def output_size(self):
return self._cell.output_size
def zero_state(self, batch_size, dtype):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
return self._cell.zero_state(batch_size, dtype)
def _highway(self, inp, out):
input_size = inp.get_shape().with_rank(2)[1].value
carry_weight = vs.get_variable("carry_w", [input_size, input_size])
carry_bias = vs.get_variable(
"carry_b", [input_size],
initializer=init_ops.constant_initializer(
self._carry_bias_init))
carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
if self._couple_carry_transform_gates:
transform = 1 - carry
else:
transform_weight = vs.get_variable("transform_w",
[input_size, input_size])
transform_bias = vs.get_variable(
"transform_b", [input_size],
initializer=init_ops.constant_initializer(
-self._carry_bias_init))
transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp,
transform_weight,
transform_bias))
return inp * carry + out * transform
def __call__(self, inputs, state, scope=None):
"""Run the cell and add its inputs to its outputs.
Args:
inputs: cell inputs.
state: cell state.
scope: optional cell scope.
Returns:
Tuple of cell outputs and new state.
Raises:
TypeError: If cell inputs and outputs have different structure (type).
ValueError: If cell inputs and outputs have different structure (value).
"""
outputs, new_state = self._cell(inputs, state, scope=scope)
nest.assert_same_structure(inputs, outputs)
# Ensure shapes match
def assert_shape_match(inp, out):
inp.get_shape().assert_is_compatible_with(out.get_shape())
nest.map_structure(assert_shape_match, inputs, outputs)
res_outputs = nest.map_structure(self._highway, inputs, outputs)
return (res_outputs, new_state)
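# A minimal usage sketch (illustrative only; assumes `import tensorflow as tf`,
# that this wrapper is exported as tf.contrib.rnn.HighwayWrapper, and that
# `inputs` is a float32 tensor whose last dimension equals the cell size):
#
#   base_cell = tf.contrib.rnn.GRUCell(64)
#   hw_cell = tf.contrib.rnn.HighwayWrapper(base_cell, carry_bias_init=1.0)
#   outputs, state = tf.nn.dynamic_rnn(hw_cell, inputs, dtype=tf.float32)
#
# With the default carry_bias_init of 1.0 the carry gate starts mostly open,
# so early in training the wrapper largely passes its inputs through.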
class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
"""LSTM unit with layer normalization and recurrent dropout.

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
// static
@ -58,11 +60,17 @@ void DebuggerStateRegistry::RegisterFactory(
}
// static
std::unique_ptr<DebuggerStateInterface> DebuggerStateRegistry::CreateState(
const DebugOptions& debug_options) {
return (factory_ == nullptr || *factory_ == nullptr)
? nullptr
: (*factory_)(debug_options);
Status DebuggerStateRegistry::CreateState(
const DebugOptions& debug_options,
std::unique_ptr<DebuggerStateInterface>* state) {
if (factory_ == nullptr || *factory_ == nullptr) {
return errors::Internal(
"Creation of debugger state failed. "
"It appears that TFDBG is not linked in this TensorFlow build.");
} else {
*state = (*factory_)(debug_options);
return Status::OK();
}
}
// static
@ -73,10 +81,17 @@ void DebugGraphDecoratorRegistry::RegisterFactory(
}
// static
std::unique_ptr<DebugGraphDecoratorInterface>
DebugGraphDecoratorRegistry::CreateDecorator(const DebugOptions& options) {
return (factory_ == nullptr || *factory_ == nullptr) ? nullptr
: (*factory_)(options);
Status DebugGraphDecoratorRegistry::CreateDecorator(
const DebugOptions& options,
std::unique_ptr<DebugGraphDecoratorInterface>* decorator) {
if (factory_ == nullptr || *factory_ == nullptr) {
return errors::Internal(
"Creation of graph decorator failed. "
"It appears that TFDBG is not linked in this TensorFlow build.");
} else {
*decorator = (*factory_)(options);
return Status::OK();
}
}
} // end namespace tensorflow

View File

@ -83,11 +83,12 @@ class DebuggerStateRegistry {
// implementation based on DebugOptions.
static void RegisterFactory(const DebuggerStateFactory& factory);
// If RegisterFactory() has been called, creates and returns a concrete
// If RegisterFactory() has been called, creates and supplies a concrete
// DebuggerStateInterface implementation using the registered factory,
// owned by the caller. Otherwise returns nullptr.
static std::unique_ptr<DebuggerStateInterface> CreateState(
const DebugOptions& debug_options);
// owned by the caller, and returns an OK Status. Otherwise returns an error
// Status.
static Status CreateState(const DebugOptions& debug_options,
std::unique_ptr<DebuggerStateInterface>* state);
private:
static DebuggerStateFactory* factory_;
@ -103,8 +104,9 @@ class DebugGraphDecoratorRegistry {
public:
static void RegisterFactory(const DebugGraphDecoratorFactory& factory);
static std::unique_ptr<DebugGraphDecoratorInterface> CreateDecorator(
const DebugOptions& options);
static Status CreateDecorator(
const DebugOptions& options,
std::unique_ptr<DebugGraphDecoratorInterface>* decorator);
private:
static DebugGraphDecoratorFactory* factory_;

View File

@ -376,31 +376,19 @@ Status DirectSession::CreateDebuggerState(
const std::vector<string>& output_names,
const std::vector<string>& target_names,
std::unique_ptr<DebuggerStateInterface>* debugger_state) {
std::unique_ptr<DebuggerStateInterface> state =
DebuggerStateRegistry::CreateState(debug_options);
if (!state) {
return errors::Internal(
"Debugger options are set, but creation of debugger state failed. "
"It appears that debugger is not linked in this TensorFlow build.");
}
TF_RETURN_IF_ERROR(state->PublishDebugMetadata(
TF_RETURN_IF_ERROR(
DebuggerStateRegistry::CreateState(debug_options, debugger_state));
TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
debug_options.global_step(), session_run_count, executor_step_count,
input_names, output_names, target_names));
*debugger_state = std::move(state);
return Status::OK();
}
Status DirectSession::DecorateAndPublishGraphForDebug(
const DebugOptions& debug_options, Graph* graph, Device* device) {
std::unique_ptr<DebugGraphDecoratorInterface> decorator =
DebugGraphDecoratorRegistry::CreateDecorator(debug_options);
if (!decorator) {
return errors::Internal(
"Debugger options are set, but creation of debug graph publisher ",
"failed.");
}
std::unique_ptr<DebugGraphDecoratorInterface> decorator;
TF_RETURN_IF_ERROR(
DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph));

View File

@ -456,7 +456,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
void DumpGraph(StringPiece label, const Graph* g) {
// TODO(zhifengc): Change Graph to record #nodes.
VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
<< g->edges().size();
<< g->num_edges();
if (VLOG_IS_ON(2)) {
for (const auto& line : str_util::Split(DebugString(g), '\n')) {
VLOG(2) << "|| " << line;

View File

@ -424,7 +424,7 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
n8 = NoOp() @ n4
n9 = Identity[T=float](n3) @ n8
n10 = Identity[T=float](n2) @ n8
n11 = NoOp() @ n10, n9
n11 = NoOp() @ n9, n10
n5 = Mul[T=float](n2, n2) @ n11
n6 = Add[T=float](n4, n5)
}
@ -500,8 +500,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
OptimizeGraph(lib_, &g);
const char* e2 = R"P(
(n2:float, n3:float) -> (n9:float) {
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]()
n10 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]()
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]()
n6 = Shape[T=float, out_type=int32](n2)
n5 = Mul[T=float](n3, n10)
n7 = BroadcastGradientArgs[T=int32](n6, n11)
@ -614,10 +614,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
n17 = Sum[T=float, Tidx=int32, keep_dims=false](n14, n16)
n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, Tidx=int32, keep_dims=false]](n14, n16, n26)
n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19)
n28 = Identity[T=float](n21:1)
n27 = Identity[T=float](n21)
n6 = Identity[T=float](n28)
n28 = Identity[T=float](n21:1)
n8 = Identity[T=float](n27)
n6 = Identity[T=float](n28)
}
)P";
EXPECT_EQ(e1, DebugString(g.get()));
@ -626,8 +626,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
const char* e2 = R"P(
(n4:float, n3:float) -> (n25:float, n23:float) {
n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
n19 = Shape[T=float, out_type=int32](n3)
n9 = Add[T=float](n4, n3)
n20 = Shape[T=float, out_type=int32](n4)
@ -641,10 +641,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
n16 = Reshape[T=float, Tshape=int32](n2, n15)
n17 = Div[T=int32](n14, n15)
n18 = Tile[T=float, Tmultiples=int32](n16, n17)
n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21)
n22 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21:1)
n25 = Reshape[T=float, Tshape=int32](n24, n20)
n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21)
n23 = Reshape[T=float, Tshape=int32](n22, n19)
n25 = Reshape[T=float, Tshape=int32](n24, n20)
}
)P";
EXPECT_EQ(e2, DebugString(g.get()));

View File

@ -97,16 +97,11 @@ static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
Status GraphMgr::DecorateAndPublishGraphForDebug(
const DebugOptions& debug_options, Graph* graph, Device* device) {
std::unique_ptr<DebugGraphDecoratorInterface> decorator =
DebugGraphDecoratorRegistry::CreateDecorator(debug_options);
if (!decorator) {
return errors::Internal(
"Debugger options are set, but creation of debug graph publisher ",
"failed.");
}
std::unique_ptr<DebugGraphDecoratorInterface> decorator;
TF_RETURN_IF_ERROR(
DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph));
return Status::OK();
}

View File

@ -1337,13 +1337,8 @@ Status MasterSession::CreateDebuggerState(
const DebugOptions& debug_options, const RunStepRequestWrapper& req,
int64 rcg_execution_count,
std::unique_ptr<DebuggerStateInterface>* debugger_state) {
std::unique_ptr<DebuggerStateInterface> state =
DebuggerStateRegistry::CreateState(debug_options);
if (!state) {
return errors::Internal(
"Debugger options are set, but creation of debugger state failed. "
"It appears that debugger is not linked in this TensorFlow build.");
}
TF_RETURN_IF_ERROR(
DebuggerStateRegistry::CreateState(debug_options, debugger_state));
std::vector<string> input_names;
for (size_t i = 0; i < req.num_feeds(); ++i) {
@ -1362,11 +1357,10 @@ Status MasterSession::CreateDebuggerState(
// While this counter value is straightforward to define and obtain for
// DirectSessions, it is less so for non-direct Sessions. Devise a better
// way to get its value when the need arises.
TF_RETURN_IF_ERROR(state->PublishDebugMetadata(
TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
debug_options.global_step(), -1, rcg_execution_count, input_names,
output_names, target_names));
*debugger_state = std::move(state);
return Status::OK();
}

View File

@ -882,6 +882,12 @@ Status FunctionLibraryDefinition::AddFunctionDef(const FunctionDef& fdef) {
fdef.signature().name(),
" already exists in function library.");
}
const OpDef* op_def;
if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
return errors::InvalidArgument(
"Cannot add function '", fdef.signature().name(),
"' because an op with the same name already exists.");
}
ptr.reset(new FunctionDefAndOpRegistration(fdef));
return Status::OK();
}

View File

@ -944,6 +944,15 @@ TEST(FunctionLibraryDefinitionTest, AddFunctionDef) {
ASSERT_NE(second, nullptr);
EXPECT_EQ(second->DebugString(),
test::function::WXPlusB().signature().DebugString());
// Can't add function with same name as existing op
FunctionDef fdef = test::function::XTimesTwo();
fdef.mutable_signature()->set_name("Add");
Status s = lib_def.AddFunctionDef(fdef);
EXPECT_FALSE(s.ok());
EXPECT_EQ(s.error_message(),
"Cannot add function 'Add' because an op with the same name "
"already exists.");
}
TEST(FunctionLibraryDefinitionTest, AddGradientDef) {

View File

@ -432,6 +432,7 @@ class OpOutputList {
OpOutputList& operator=(const OpOutputList& other) = default;
Tensor* operator[](int i);
bool required(int i) const;
DataType expected_output_dtype(int i) const;
Status allocate(int i, const TensorShape& shape, Tensor** output);
void set(int i, const Tensor& tensor);
void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
@ -1452,6 +1453,12 @@ inline bool OpOutputList::required(int i) const {
return ctx_->output_required(start_ + i);
}
inline DataType OpOutputList::expected_output_dtype(int i) const {
DCHECK_GE(i, 0);
DCHECK_LT(i, stop_ - start_);
return ctx_->expected_output_dtype(start_ + i);
}
inline Status OpOutputList::allocate(int i, const TensorShape& shape,
Tensor** output) {
DCHECK_GE(i, 0);

View File

@ -613,6 +613,36 @@ TEST_F(OpKernelBuilderTest, BadConstraint) {
error::INVALID_ARGUMENT);
}
REGISTER_OP("ListOut").Output("a: int32").Output("b: T").Attr("T: list(type)");
REGISTER_KERNEL_BUILDER(Name("ListOut").Device(tensorflow::DEVICE_CPU),
DummyKernel);
TEST_F(OpKernelBuilderTest, OpOutputList) {
Env* env = Env::Default();
OpKernelContext::Params params;
params.record_tensor_accesses = false;
std::unique_ptr<DummyDevice> device(
new DummyDevice(env, params.record_tensor_accesses));
params.device = device.get();
Status status;
std::unique_ptr<OpKernel> op(CreateOpKernel(
DEVICE_CPU, params.device, cpu_allocator(),
CreateNodeDef("ListOut", {"T|list(type)|[DT_FLOAT, DT_INT32]"}),
TF_GRAPH_DEF_VERSION, &status));
EXPECT_TRUE(status.ok()) << status.ToString();
params.op_kernel = op.get();
gtl::InlinedVector<TensorValue, 4> inputs{};
params.inputs = &inputs;
std::unique_ptr<OpKernelContext> ctx(new OpKernelContext(&params));
EXPECT_EQ(DT_INT32, ctx->expected_output_dtype(0));
OpOutputList out_list;
EXPECT_FALSE(ctx->output_list("non_existent_output", &out_list).ok());
ASSERT_TRUE(ctx->output_list("b", &out_list).ok());
EXPECT_EQ(DT_FLOAT, out_list.expected_output_dtype(0));
EXPECT_EQ(DT_INT32, out_list.expected_output_dtype(1));
}
class GetAttrKernel : public ::tensorflow::OpKernel {
public:
explicit GetAttrKernel(OpKernelConstruction* context) : OpKernel(context) {

View File

@ -246,6 +246,14 @@ ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) {
return ctx->input(input).flat<ResourceHandle>()(0);
}
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle) {
const Tensor* tensor;
TF_RETURN_IF_ERROR(ctx->input(input, &tensor));
*handle = tensor->flat<ResourceHandle>()(0);
return Status::OK();
}
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
return ctx->resource_manager()->Delete(p);

View File

@ -211,6 +211,8 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
// Returns a resource handle from a numbered op input.
ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle);
// Create a resource pointed by a given resource handle.
template <typename T>

View File

@ -344,7 +344,7 @@ const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
CHECK(source->out_edges_.insert(e).second);
CHECK(dest->in_edges_.insert(e).second);
edges_.push_back(e);
edge_set_.insert(e);
++num_edges_;
return e;
}
@ -354,8 +354,8 @@ void Graph::RemoveEdge(const Edge* e) {
CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
CHECK_EQ(e, edges_[e->id_]);
CHECK_GT(num_edges_, 0);
CHECK_EQ(edge_set_.erase(e), size_t{1});
edges_[e->id_] = nullptr;
Edge* del = const_cast<Edge*>(e);
@ -365,6 +365,7 @@ void Graph::RemoveEdge(const Edge* e) {
del->src_output_ = kControlSlot - 1;
del->dst_input_ = kControlSlot - 1;
free_edges_.push_back(del);
--num_edges_;
}
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
@ -380,13 +381,6 @@ Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
// Ignore duplicate FunctionDefs
continue;
}
// TODO(skyewm): fix test breakages and reenable this check
// const OpDef* op_def;
// if (ops_.LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
// return errors::InvalidArgument(
// "Cannot add function '", fdef.signature().name(),
// "' because an op with the same name already exists.");
// }
TF_RETURN_IF_ERROR(ops_.AddFunctionDef(fdef));
}
for (const GradientDef& grad : fdef_lib.gradient()) {

View File

@ -268,6 +268,66 @@ class Edge {
int dst_input_;
};
// Allows for iteration of the edges of a Graph, by iterating the underlying
// Graph.edges_ vector while skipping over null entries.
class GraphEdgesIterable {
private:
const std::vector<Edge*>& edges_;
public:
explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
: edges_(edges) {}
typedef Edge* value_type;
class const_iterator {
private:
// The underlying iterator.
std::vector<value_type>::const_iterator iter_;
// The end of the underlying iterator.
std::vector<value_type>::const_iterator end_;
// Advances iter_ until it reaches a non-null item, or reaches the end.
void apply_filter() {
while (iter_ != end_ && *iter_ == nullptr) {
++iter_;
}
}
public:
const_iterator(std::vector<value_type>::const_iterator iter,
std::vector<value_type>::const_iterator end)
: iter_(iter), end_(end) {
apply_filter();
}
bool operator==(const const_iterator& other) const {
return iter_ == other.iter_;
}
bool operator!=(const const_iterator& other) const {
return iter_ != other.iter_;
}
// This is the prefix increment operator (++x), which is the operator
// used by C++ range iteration (for (x : y) ...). We intentionally do not
// provide a postfix increment operator.
const_iterator& operator++() {
++iter_;
apply_filter();
return *this;
}
value_type operator*() { return *iter_; }
};
const_iterator begin() {
return const_iterator(edges_.begin(), edges_.end());
}
const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
};
// Thread compatible but not thread safe.
class Graph {
public:
@ -345,7 +405,7 @@ class Graph {
// smaller than num_edge_ids(). If one needs to create an array of
// edges indexed by edge ids, num_edge_ids() should be used as the
// array's size.
int num_edges() const { return edges().size(); }
int num_edges() const { return num_edges_; }
// Serialize the nodes starting at `from_node_id` to a GraphDef.
void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
@ -381,7 +441,7 @@ class Graph {
// Access to the set of all edges. Example usage:
// for (const Edge* e : graph.edges()) { ... }
const EdgeSet& edges() const { return edge_set_; }
GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
// The pre-defined nodes.
enum { kSourceId = 0, kSinkId = 1 };
@ -421,9 +481,8 @@ class Graph {
// the edge with that id was removed from the graph.
std::vector<Edge*> edges_;
// For ease of iteration, we currently just keep a set of all live
// edges. May want to optimize by removing this copy.
EdgeSet edge_set_;
// The number of entries in edges_ that are not nullptr.
int num_edges_ = 0;
// Allocated but free nodes and edges.
std::vector<Node*> free_nodes_;

View File

@ -412,15 +412,14 @@ TEST_F(GraphTest, AddFunctionLibrary) {
"Cannot add function 'XTimesTwo' because a different function with "
"the same name already exists.");
// TODO(skyewm): reenable along with duplicate op check
// Function with same name as an existing op triggers an error
// error_proto = proto;
// error_proto.mutable_function(0)->mutable_signature()->set_name("Add");
// s = graph_.AddFunctionLibrary(error_proto);
// EXPECT_FALSE(s.ok());
// EXPECT_EQ(s.error_message(),
// "Cannot add function 'Add' because an op with the same name "
// "already exists.");
error_proto = proto;
error_proto.mutable_function(0)->mutable_signature()->set_name("Add");
s = graph_.AddFunctionLibrary(error_proto);
EXPECT_FALSE(s.ok());
EXPECT_EQ(s.error_message(),
"Cannot add function 'Add' because an op with the same name "
"already exists.");
// Adding a gradient function to an existing function is ok
GradientDef* grad = proto.add_gradient();

View File

@ -2249,6 +2249,7 @@ cc_library(
":batch_matmul_op",
":betainc_op",
":bincount_op",
":bucketize_op",
":cast_op",
":check_numerics_op",
":cross_op",
@ -2286,6 +2287,12 @@ tf_kernel_library(
deps = MATH_DEPS,
)
tf_kernel_library(
name = "bucketize_op",
prefix = "bucketize_op",
deps = MATH_DEPS,
)
tf_kernel_library(
name = "cast_op",
prefix = "cast_op",

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/math_ops.cc.
#include <algorithm>
#include <vector>

View File

@ -367,7 +367,9 @@ class InitializeTableOp : public OpKernel {
GetInitializableLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
DataType expected_input_0 =
(ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
table->value_dtype()};
DataTypeVector expected_outputs = {};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
@ -408,6 +410,8 @@ class InitializeTableOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("InitializeTable").Device(DEVICE_CPU),
InitializeTableOp);
REGISTER_KERNEL_BUILDER(Name("InitializeTableV2").Device(DEVICE_CPU),
InitializeTableOp);
// Kernel to initialize a lookup table from a text file.
//
@ -433,7 +437,9 @@ class InitializeTableFromTextFileOp : public OpKernel {
GetInitializableLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
DataTypeVector expected_inputs = {DT_STRING_REF, DT_STRING};
DataType expected_input_0 =
(ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
DataTypeVector expected_inputs = {expected_input_0, DT_STRING};
DataTypeVector expected_outputs = {};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
@ -472,5 +478,8 @@ class InitializeTableFromTextFileOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("InitializeTableFromTextFile").Device(DEVICE_CPU),
InitializeTableFromTextFileOp);
REGISTER_KERNEL_BUILDER(
Name("InitializeTableFromTextFileV2").Device(DEVICE_CPU),
InitializeTableFromTextFileOp);
} // namespace tensorflow

View File

@ -624,7 +624,10 @@ class LookupTableFindOp : public OpKernel {
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
// Input 0 could be a STRING_REF or a RESOURCE
DataType expected_input_0 =
(ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
table->value_dtype()};
DataTypeVector expected_outputs = {table->value_dtype()};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, expected_outputs));
@ -647,6 +650,8 @@ class LookupTableFindOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableFind").Device(DEVICE_CPU),
LookupTableFindOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableFindV2").Device(DEVICE_CPU),
LookupTableFindOp);
// Table insert op.
class LookupTableInsertOp : public OpKernel {
@ -658,7 +663,9 @@ class LookupTableInsertOp : public OpKernel {
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
DataType expected_input_0 =
(ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
table->value_dtype()};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
@ -680,6 +687,8 @@ class LookupTableInsertOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableInsert").Device(DEVICE_CPU),
LookupTableInsertOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableInsertV2").Device(DEVICE_CPU),
LookupTableInsertOp);
// Op that returns the size of the given table.
class LookupTableSizeOp : public OpKernel {
@ -699,6 +708,8 @@ class LookupTableSizeOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableSize").Device(DEVICE_CPU),
LookupTableSizeOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableSizeV2").Device(DEVICE_CPU),
LookupTableSizeOp);
// Op that outputs tensors of all keys and all values.
class LookupTableExportOp : public OpKernel {
@ -716,6 +727,8 @@ class LookupTableExportOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU),
LookupTableExportOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableExportV2").Device(DEVICE_CPU),
LookupTableExportOp);
// Clear the table and insert data.
class LookupTableImportOp : public OpKernel {
@ -727,7 +740,9 @@ class LookupTableImportOp : public OpKernel {
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
core::ScopedUnref unref_me(table);
DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
DataType expected_input_0 =
(ctx->input_dtype(0) == DT_RESOURCE) ? DT_RESOURCE : DT_STRING_REF;
DataTypeVector expected_inputs = {expected_input_0, table->key_dtype(),
table->value_dtype()};
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
@ -749,6 +764,8 @@ class LookupTableImportOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU),
LookupTableImportOp);
REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
LookupTableImportOp);
// Register the HashTable op with the currently supported key and value types.
#define REGISTER_KERNEL(key_dtype, value_dtype) \
@ -757,6 +774,13 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU),
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
value_dtype>) \
REGISTER_KERNEL_BUILDER( \
Name("HashTableV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
value_dtype>)
@ -778,6 +802,13 @@ REGISTER_KERNEL(string, bool);
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
key_dtype, value_dtype>) \
REGISTER_KERNEL_BUILDER( \
Name("MutableHashTableV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
@ -796,6 +827,13 @@ REGISTER_KERNEL(int64, float);
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
key_dtype, value_dtype>) \
REGISTER_KERNEL_BUILDER( \
Name("MutableHashTableOfTensorsV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
@ -813,6 +851,13 @@ REGISTER_KERNEL(string, bool);
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
key_dtype, value_dtype>) \
REGISTER_KERNEL_BUILDER( \
Name("MutableDenseHashTableV2") \
.Device(DEVICE_CPU) \
.TypeConstraint<key_dtype>("key_dtype") \
.TypeConstraint<value_dtype>("value_dtype"), \
LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
key_dtype, value_dtype>)

View File

@ -51,40 +51,52 @@ class LookupTableOp : public OpKernel {
// ctx is not owned by this function.
void Compute(OpKernelContext* ctx) override {
mutex_lock l(mu_);
if (!table_handle_set_) {
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
use_node_name_sharing_));
auto creator = [ctx, this](lookup::LookupInterface** ret) {
lookup::LookupInterface* container = new Container(ctx, this);
if (!ctx->status().ok()) {
container->Unref();
return ctx->status();
}
if (ctx->track_allocations()) {
ctx->record_device_persistent_memory_allocation(
container->MemoryUsed());
}
*ret = container;
return Status::OK();
};
lookup::LookupInterface* table = nullptr;
OP_REQUIRES_OK(
ctx, cinfo_.resource_manager()
->template LookupOrCreate<lookup::LookupInterface>(
cinfo_.container(), cinfo_.name(), &table, creator));
core::ScopedUnref unref_me(table);
OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
*table, DataTypeToEnum<key_dtype>::v(),
DataTypeToEnum<value_dtype>::v(), cinfo_.name()));
auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
h(0) = cinfo_.container();
h(1) = cinfo_.name();
table_handle_set_ = true;
}
ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
auto creator = [ctx, this](lookup::LookupInterface** ret) {
lookup::LookupInterface* container = new Container(ctx, this);
if (!ctx->status().ok()) {
container->Unref();
return ctx->status();
}
if (ctx->track_allocations()) {
ctx->record_device_persistent_memory_allocation(
container->MemoryUsed());
}
*ret = container;
return Status::OK();
};
lookup::LookupInterface* table = nullptr;
OP_REQUIRES_OK(ctx,
cinfo_.resource_manager()
->template LookupOrCreate<lookup::LookupInterface>(
cinfo_.container(), cinfo_.name(), &table, creator));
core::ScopedUnref unref_me(table);
OP_REQUIRES_OK(ctx, lookup::CheckTableDataTypes(
*table, DataTypeToEnum<key_dtype>::v(),
DataTypeToEnum<value_dtype>::v(), cinfo_.name()));
if (ctx->expected_output_dtype(0) == DT_RESOURCE) {
Tensor* handle;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle));
handle->scalar<ResourceHandle>()() =
MakeResourceHandle<lookup::LookupInterface>(ctx, cinfo_.container(),
cinfo_.name());
} else {
if (!table_handle_set_) {
auto h = table_handle_.AccessTensor(ctx)->template flat<string>();
h(0) = cinfo_.container();
h(1) = cinfo_.name();
}
ctx->set_output_ref(0, &mu_, table_handle_.AccessTensor(ctx));
}
table_handle_set_ = true;
}
~LookupTableOp() override {

View File

@ -49,26 +49,48 @@ Status GetLookupTable(const string& input_name, OpKernelContext* ctx,
LookupInterface** table) {
string container;
string table_handle;
TF_RETURN_IF_ERROR(
GetTableHandle(input_name, ctx, &container, &table_handle));
return ctx->resource_manager()->Lookup(container, table_handle, table);
DataType handle_dtype;
TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
if (handle_dtype == DT_RESOURCE) {
ResourceHandle handle;
TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
return LookupResource(ctx, handle, table);
} else {
TF_RETURN_IF_ERROR(
GetTableHandle(input_name, ctx, &container, &table_handle));
return ctx->resource_manager()->Lookup(container, table_handle, table);
}
}
Status GetInitializableLookupTable(const string& input_name,
OpKernelContext* ctx,
InitializableLookupTable** table) {
string container;
string table_handle;
TF_RETURN_IF_ERROR(
GetTableHandle(input_name, ctx, &container, &table_handle));
LookupInterface* lookup_table;
TF_RETURN_IF_ERROR(
ctx->resource_manager()->Lookup(container, table_handle, &lookup_table));
*table = lookup_table->GetInitializableLookupTable();
if (*table == nullptr) {
lookup_table->Unref();
return errors::InvalidArgument("Table ", container, " ", table_handle,
" is not initializable");
DataType handle_dtype;
TF_RETURN_IF_ERROR(ctx->input_dtype(input_name, &handle_dtype));
if (handle_dtype == DT_RESOURCE) {
ResourceHandle handle;
TF_RETURN_IF_ERROR(HandleFromInput(ctx, input_name, &handle));
TF_RETURN_IF_ERROR(LookupResource(ctx, handle, &lookup_table));
*table = lookup_table->GetInitializableLookupTable();
if (*table == nullptr) {
lookup_table->Unref();
return errors::InvalidArgument("Table ", handle.container(), " ",
handle.name(), " is not initializable");
}
} else {
string container;
string table_handle;
TF_RETURN_IF_ERROR(
GetTableHandle(input_name, ctx, &container, &table_handle));
TF_RETURN_IF_ERROR(ctx->resource_manager()->Lookup(container, table_handle,
&lookup_table));
*table = lookup_table->GetInitializableLookupTable();
if (*table == nullptr) {
lookup_table->Unref();
return errors::InvalidArgument("Table ", container, " ", table_handle,
" is not initializable");
}
}
return Status::OK();
}

View File

@ -3592,6 +3592,33 @@ op {
}
}
}
op {
name: "Bucketize"
input_arg {
name: "input"
type_attr: "T"
}
output_arg {
name: "output"
type: DT_INT32
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "boundaries"
type: "list(float)"
}
}
op {
name: "CTCBeamSearchDecoder"
input_arg {
@ -8878,6 +8905,43 @@ op {
}
is_stateful: true
}
op {
name: "HashTableV2"
output_arg {
name: "table_handle"
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
}
attr {
name: "use_node_name_sharing"
type: "bool"
default_value {
b: false
}
}
attr {
name: "key_dtype"
type: "type"
}
attr {
name: "value_dtype"
type: "type"
}
is_stateful: true
}
op {
name: "HistogramSummary"
input_arg {
@ -9309,6 +9373,70 @@ op {
}
}
}
op {
name: "InitializeTableFromTextFileV2"
input_arg {
name: "table_handle"
type: DT_RESOURCE
}
input_arg {
name: "filename"
type: DT_STRING
}
attr {
name: "key_index"
type: "int"
has_minimum: true
minimum: -2
}
attr {
name: "value_index"
type: "int"
has_minimum: true
minimum: -2
}
attr {
name: "vocab_size"
type: "int"
default_value {
i: -1
}
has_minimum: true
minimum: -1
}
attr {
name: "delimiter"
type: "string"
default_value {
s: "\t"
}
}
is_stateful: true
}
op {
name: "InitializeTableV2"
input_arg {
name: "table_handle"
type: DT_RESOURCE
}
input_arg {
name: "keys"
type_attr: "Tkey"
}
input_arg {
name: "values"
type_attr: "Tval"
}
attr {
name: "Tkey"
type: "type"
}
attr {
name: "Tval"
type: "type"
}
is_stateful: true
}
op {
name: "Inv"
input_arg {
@ -10155,6 +10283,30 @@ op {
type: "type"
}
}
op {
name: "LookupTableExportV2"
input_arg {
name: "table_handle"
type: DT_RESOURCE
}
output_arg {
name: "keys"
type_attr: "Tkeys"
}
output_arg {
name: "values"
type_attr: "Tvalues"
}
attr {
name: "Tkeys"
type: "type"
}
attr {
name: "Tvalues"
type: "type"
}
is_stateful: true
}
op {
name: "LookupTableFind"
input_arg {
@ -10183,6 +10335,34 @@ op {
type: "type"
}
}
op {
name: "LookupTableFindV2"
input_arg {
name: "table_handle"
type: DT_RESOURCE
}
input_arg {
name: "keys"
type_attr: "Tin"
}
input_arg {
name: "default_value"
type_attr: "Tout"
}
output_arg {
name: "values"
type_attr: "Tout"
}
attr {
name: "Tin"
type: "type"
}
attr {
name: "Tout"
type: "type"
}
is_stateful: true
}
op {
name: "LookupTableImport"
input_arg {
@ -10207,6 +10387,30 @@ op {
type: "type"
}
}
op {
name: "LookupTableImportV2"
input_arg {
name: "table_handle"
type: DT_RESOURCE
}
input_arg {
name: "keys"
type_attr: "Tin"
}
input_arg {
name: "values"
type_attr: "Tout"
}
attr {
name: "Tin"
type: "type"
}
attr {
name: "Tout"
type: "type"
}
is_stateful: true
}
op {
name: "LookupTableInsert"
input_arg {
@ -10231,6 +10435,30 @@ op {
type: "type"
}
}
op {
name: "LookupTableInsertV2"
input_arg {
name: "table_handle"
type: DT_RESOURCE
}
input_arg {
name: "keys"
type_attr: "Tin"
}
input_arg {
name: "values"
type_attr: "Tout"
}
attr {
name: "Tin"
type: "type"
}
attr {
name: "Tout"
type: "type"
}
is_stateful: true
}
op {
name: "LookupTableSize"
input_arg {
@ -10243,6 +10471,18 @@ op {
type: DT_INT64
}
}
op {
name: "LookupTableSizeV2"
input_arg {
name: "table_handle"
type: DT_RESOURCE
}
output_arg {
name: "size"
type: DT_INT64
}
is_stateful: true
}
op {
name: "LoopCond"
input_arg {
@ -12147,6 +12387,69 @@ op {
}
is_stateful: true
}
op {
name: "MutableDenseHashTableV2"
input_arg {
name: "empty_key"
type_attr: "key_dtype"
}
output_arg {
name: "table_handle"
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
}
attr {
name: "use_node_name_sharing"
type: "bool"
default_value {
b: false
}
}
attr {
name: "key_dtype"
type: "type"
}
attr {
name: "value_dtype"
type: "type"
}
attr {
name: "value_shape"
type: "shape"
default_value {
shape {
}
}
}
attr {
name: "initial_num_buckets"
type: "int"
default_value {
i: 131072
}
}
attr {
name: "max_load_factor"
type: "float"
default_value {
f: 0.8
}
}
is_stateful: true
}
op {
name: "MutableHashTable"
output_arg {
@ -12231,6 +12534,88 @@ op {
}
is_stateful: true
}
op {
name: "MutableHashTableOfTensorsV2"
output_arg {
name: "table_handle"
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
}
attr {
name: "use_node_name_sharing"
type: "bool"
default_value {
b: false
}
}
attr {
name: "key_dtype"
type: "type"
}
attr {
name: "value_dtype"
type: "type"
}
attr {
name: "value_shape"
type: "shape"
default_value {
shape {
}
}
}
is_stateful: true
}
op {
name: "MutableHashTableV2"
output_arg {
name: "table_handle"
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
}
attr {
name: "use_node_name_sharing"
type: "bool"
default_value {
b: false
}
}
attr {
name: "key_dtype"
type: "type"
}
attr {
name: "value_dtype"
type: "type"
}
is_stateful: true
}
op {
name: "Neg"
input_arg {

View File

@ -210,10 +210,29 @@ Status TwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
return Status::OK();
}
Status ScalarAndTwoElementVectorInputsAndScalarOutputs(InferenceContext* c) {
ShapeHandle handle;
DimensionHandle unused_handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
for (int i = 1; i < c->num_inputs(); ++i) {
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
}
for (int i = 0; i < c->num_outputs(); ++i) {
c->set_output(i, c->Scalar());
}
return Status::OK();
}
Status TwoElementOutput(InferenceContext* c) {
c->set_output(0, c->Vector(2));
return Status::OK();
}
Status ScalarOutput(InferenceContext* c) {
c->set_output(0, c->Scalar());
return Status::OK();
}
} // namespace
REGISTER_OP("RandomShuffleQueue")
@ -1881,6 +1900,38 @@ values: Same shape as `keys`. Values found in the table, or `default_values`
for missing keys.
)doc");
REGISTER_OP("LookupTableFindV2")
.Input("table_handle: resource")
.Input("keys: Tin")
.Input("default_value: Tout")
.Output("values: Tout")
.Attr("Tin: type")
.Attr("Tout: type")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
// Default value must be scalar or vector.
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
c->set_output(0, c->UnknownShape());
return Status::OK();
})
.Doc(R"doc(
Looks up keys in a table, outputs the corresponding values.
The tensor `keys` must be of the same type as the keys of the table.
The output `values` is of the type of the table values.
The scalar `default_value` is the value output for keys not present in the
table. It must also be of the same type as the table values.
table_handle: Handle to the table.
keys: Any shape. Keys to look up.
values: Same shape as `keys`. Values found in the table, or `default_values`
for missing keys.
)doc");
REGISTER_OP("LookupTableInsert")
.Input("table_handle: Ref(string)")
.Input("keys: Tin")
@ -1893,6 +1944,30 @@ REGISTER_OP("LookupTableInsert")
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
// TODO(ebrevdo): Validate keys and values shape.
return Status::OK();
})
.Doc(R"doc(
Updates the table to associate keys with values.
The tensor `keys` must be of the same type as the keys of the table.
The tensor `values` must be of the type of the table values.
table_handle: Handle to the table.
keys: Any shape. Keys to look up.
values: Values to associate with keys.
)doc");
REGISTER_OP("LookupTableInsertV2")
.Input("table_handle: resource")
.Input("keys: Tin")
.Input("values: Tout")
.Attr("Tin: type")
.Attr("Tout: type")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
// TODO: Validate keys and values shape.
return Status::OK();
})
@ -1918,6 +1993,17 @@ table_handle: Handle to the table.
size: Scalar that contains number of elements in the table.
)doc");
REGISTER_OP("LookupTableSizeV2")
.Input("table_handle: resource")
.Output("size: int64")
.SetShapeFn(ScalarAndTwoElementVectorInputsAndScalarOutputs)
.Doc(R"doc(
Computes the number of elements in the given table.
table_handle: Handle to the table.
size: Scalar that contains number of elements in the table.
)doc");
REGISTER_OP("LookupTableExport")
.Input("table_handle: Ref(string)")
.Output("keys: Tkeys")
@ -1945,6 +2031,31 @@ keys: Vector of all keys present in the table.
values: Tensor of all values in the table. Indexed in parallel with `keys`.
)doc");
REGISTER_OP("LookupTableExportV2")
.Input("table_handle: resource")
.Output("keys: Tkeys")
.Output("values: Tvalues")
.Attr("Tkeys: type")
.Attr("Tvalues: type")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
ShapeHandle values = c->UnknownShape();
TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
ShapeHandle keys = c->Vector(c->Dim(values, 0));
c->set_output(0, keys);
c->set_output(1, values);
return Status::OK();
})
.Doc(R"doc(
Outputs all keys and values in the table.
table_handle: Handle to the table.
keys: Vector of all keys present in the table.
values: Tensor of all values in the table. Indexed in parallel with `keys`.
)doc");
REGISTER_OP("LookupTableImport")
.Input("table_handle: Ref(string)")
.Input("keys: Tin")
@ -1957,6 +2068,30 @@ REGISTER_OP("LookupTableImport")
DimensionHandle unused_dim;
TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_dim));
// TODO(ebrevdo): Validate keys and values shape.
return Status::OK();
})
.Doc(R"doc(
Replaces the contents of the table with the specified keys and values.
The tensor `keys` must be of the same type as the keys of the table.
The tensor `values` must be of the type of the table values.
table_handle: Handle to the table.
keys: Any shape. Keys to look up.
values: Values to associate with keys.
)doc");
REGISTER_OP("LookupTableImportV2")
.Input("table_handle: resource")
.Input("keys: Tin")
.Input("values: Tout")
.Attr("Tin: type")
.Attr("Tout: type")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
// TODO: Validate keys and values shape.
return Status::OK();
})
@ -1998,6 +2133,33 @@ key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
REGISTER_OP("HashTableV2")
.Output("table_handle: resource")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("use_node_name_sharing: bool = false")
.Attr("key_dtype: type")
.Attr("value_dtype: type")
.SetIsStateful()
.SetShapeFn(ScalarOutput)
.Doc(R"doc(
Creates a non-initialized hash table.
This op creates a hash table, specifying the type of its keys and values.
Before using the table you will have to initialize it. After initialization the
table will be immutable.
table_handle: Handle to a table.
container: If non-empty, this table is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this table is shared under the given name across
multiple sessions.
use_node_name_sharing: If true and shared_name is empty, the table is shared
using the node name.
key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
REGISTER_OP("MutableHashTable")
.Output("table_handle: Ref(string)")
.Attr("container: string = ''")
@ -2025,6 +2187,33 @@ key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
REGISTER_OP("MutableHashTableV2")
.Output("table_handle: resource")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("use_node_name_sharing: bool = false")
.Attr("key_dtype: type")
.Attr("value_dtype: type")
.SetIsStateful()
.SetShapeFn(ScalarOutput)
.Doc(R"doc(
Creates an empty hash table.
This op creates a mutable hash table, specifying the type of its keys and
values. Each value must be a scalar. Data can be inserted into the table using
the insert operations. It does not support the initialization operation.
table_handle: Handle to a table.
container: If non-empty, this table is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this table is shared under the given name across
multiple sessions.
use_node_name_sharing: If true and shared_name is empty, the table is shared
using the node name.
key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
REGISTER_OP("MutableHashTableOfTensors")
.Output("table_handle: Ref(string)")
.Attr("container: string = ''")
@ -2051,6 +2240,32 @@ key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
REGISTER_OP("MutableHashTableOfTensorsV2")
.Output("table_handle: resource")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("use_node_name_sharing: bool = false")
.Attr("key_dtype: type")
.Attr("value_dtype: type")
.Attr("value_shape: shape = {}")
.SetIsStateful()
.SetShapeFn(ScalarOutput)
.Doc(R"doc(
Creates an empty hash table.
This op creates a mutable hash table, specifying the type of its keys and
values. Each value must be a vector. Data can be inserted into the table using
the insert operations. It does not support the initialization operation.
table_handle: Handle to a table.
container: If non-empty, this table is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this table is shared under the given name across
multiple sessions.
key_dtype: Type of the table keys.
value_dtype: Type of the table values.
)doc");
REGISTER_OP("MutableDenseHashTable")
.Input("empty_key: key_dtype")
.Output("table_handle: Ref(string)")
@ -2088,6 +2303,43 @@ max_load_factor: The maximum ratio between number of entries and number of
buckets before growing the table. Must be between 0 and 1.
)doc");
REGISTER_OP("MutableDenseHashTableV2")
.Input("empty_key: key_dtype")
.Output("table_handle: resource")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("use_node_name_sharing: bool = false")
.Attr("key_dtype: type")
.Attr("value_dtype: type")
.Attr("value_shape: shape = {}")
.Attr("initial_num_buckets: int = 131072") // 2^17
.Attr("max_load_factor: float = 0.8")
.SetIsStateful()
.SetShapeFn(ScalarOutput)
.Doc(R"doc(
Creates an empty hash table that uses tensors as the backing store. It uses
"open addressing" with quadratic reprobing to resolve collisions.
This op creates a mutable hash table, specifying the type of its keys and
values. Each value must be a scalar. Data can be inserted into the table using
the insert operations. It does not support the initialization operation.
empty_key: The key used to represent empty key buckets internally. Must not
be used in insert or lookup operations.
table_handle: Handle to a table.
container: If non-empty, this table is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this table is shared under the given name across
multiple sessions.
key_dtype: Type of the table keys.
value_dtype: Type of the table values.
value_shape: The shape of each value.
initial_num_buckets: The initial number of hash table buckets. Must be a power
  of 2.
max_load_factor: The maximum ratio between number of entries and number of
buckets before growing the table. Must be between 0 and 1.
)doc");
REGISTER_OP("InitializeTable")
.Input("table_handle: Ref(string)")
.Input("keys: Tkey")
@ -2113,6 +2365,29 @@ keys: Keys of type Tkey.
values: Values of type Tval.
)doc");
REGISTER_OP("InitializeTableV2")
.Input("table_handle: resource")
.Input("keys: Tkey")
.Input("values: Tval")
.Attr("Tkey: type")
.Attr("Tval: type")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
ShapeHandle keys;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
return Status::OK();
})
.Doc(R"doc(
Table initializer that takes two tensors for keys and values respectively.
table_handle: Handle to a table which will be initialized.
keys: Keys of type Tkey.
values: Values of type Tval.
)doc");
REGISTER_OP("InitializeTableFromTextFile")
.Input("table_handle: Ref(string)")
.Input("filename: string")
@ -2152,6 +2427,43 @@ vocab_size: Number of elements of the file, use -1 if unknown.
delimiter: Delimiter to separate fields in a line.
)doc");
REGISTER_OP("InitializeTableFromTextFileV2")
.Input("table_handle: resource")
.Input("filename: string")
.Attr("key_index: int >= -2")
.Attr("value_index: int >= -2")
.Attr("vocab_size: int >= -1 = -1")
.Attr("delimiter: string = '\t'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &handle));
return Status::OK();
})
.Doc(R"doc(
Initializes a table from a text file.
It inserts one key-value pair into the table for each line of the file.
The key and value are extracted from the whole line content, elements from the
split line based on `delimiter` or the line number (starting from zero).
Where to extract the key and value from a line is specified by `key_index` and
`value_index`.
- A value of -1 means use the line number (starting from zero), expects `int64`.
- A value of -2 means use the whole line content, expects `string`.
- A value >= 0 means use the index (starting at zero) of the split line based
on `delimiter`.
table_handle: Handle to a table which will be initialized.
filename: Filename of a vocabulary text file.
key_index: Column index in a line to get the table `key` values from.
value_index: Column index that represents information of a line to get the table
`value` values from.
vocab_size: Number of elements of the file, use -1 if unknown.
delimiter: Delimiter to separate fields in a line.
)doc");
REGISTER_OP("GetSessionHandle")
.Input("value: T")
.Output("handle: string")

View File

@ -662,13 +662,12 @@ Compute the upper regularized incomplete Gamma function `Q(a, x)`.
The upper regularized incomplete Gamma function is defined as:
```
Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)
```
\\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\)
where
```
Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt
```
\\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\)
is the upper incomplete Gamma function.
Note, above `P(a, x)` (`Igamma`) is the lower regularized complete
@ -686,13 +685,13 @@ Compute the lower regularized incomplete Gamma function `Q(a, x)`.
The lower regularized incomplete Gamma function is defined as:
```
P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)
```
\\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\)
where
```
gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt
```
\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
is the lower incomplete Gamma function.
Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete
@ -710,9 +709,9 @@ Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
The Hurwitz zeta function is defined as:
```
\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}
```
\\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\)
)doc");
REGISTER_OP("Polygamma")
@ -726,9 +725,9 @@ Compute the polygamma function \\(\psi^{(n)}(x)\\).
The polygamma function is defined as:
```
\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)
```
\\(\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)\\)
where \\(\psi(x)\\) is the digamma function.
)doc");
@ -790,14 +789,14 @@ Compute the regularized incomplete beta integral \\(I_x(a, b)\\).
The regularized incomplete beta integral is defined as:
```
I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}
```
\\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\)
where
```
B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt
```
\\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\)
is the incomplete beta function and \\(B(a, b)\\) is the *complete*
beta function.
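For readers who want to sanity-check these definitions numerically, the regularized functions above have standard SciPy counterparts; the mapping below is an observation about SciPy's conventions (which match the regularized forms used here), not part of this change.
```
# Numerical sanity check of the definitions above using SciPy.
from scipy import special

a, b, x, q = 2.0, 3.0, 1.5, 0.5
p_ax = special.gammainc(a, x)    # P(a, x): lower regularized incomplete gamma
q_ax = special.gammaincc(a, x)   # Q(a, x): upper regularized incomplete gamma
assert abs(p_ax + q_ax - 1.0) < 1e-12     # Q(a, x) = 1 - P(a, x)

i_x = special.betainc(a, b, 0.25)         # I_x(a, b): regularized incomplete beta
hurwitz = special.zeta(x, q)              # Hurwitz zeta function zeta(x, q)
```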
@ -2371,4 +2370,35 @@ output_max: the computed max output.
)doc");
// --------------------------------------------------------------------------
REGISTER_OP("Bucketize")
.Input("input: T")
.Output("output: int32")
.Attr("T: {int32, int64, float, double}")
.Attr("boundaries: list(float)")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Bucketizes 'input' based on 'boundaries'.
For example, if the inputs are
boundaries = [0, 10, 100]
input = [[-5, 10000]
[150, 10]
[5, 100]]
then the output will be
output = [[0, 3]
[3, 2]
[1, 3]]
input: Tensor of any shape, with int or float type.
boundaries: A sorted list of floats giving the boundaries of the buckets.
output: Same shape as 'input', with each value of input replaced by its bucket index.
@compatibility(numpy)
Equivalent to np.digitize.
@end_compatibility
)doc");
} // namespace tensorflow

View File

@ -3324,7 +3324,7 @@ op {
}
}
summary: "Compute the regularized incomplete beta integral \\\\(I_x(a, b)\\\\)."
description: "The regularized incomplete beta integral is defined as:\n\n```\nI_x(a, b) = \\frac{B(x; a, b)}{B(a, b)}\n```\nwhere\n\n```\nB(x; a, b) = \\int_0^x t^{a-1} (1 - t)^{b-1} dt\n```\n\nis the incomplete beta function and \\\\(B(a, b)\\\\) is the *complete*\nbeta function."
description: "The regularized incomplete beta integral is defined as:\n\n\n\\\\(I_x(a, b) = \\frac{B(x; a, b)}{B(a, b)}\\\\)\n\nwhere\n\n\n\\\\(B(x; a, b) = \\int_0^x t^{a-1} (1 - t)^{b-1} dt\\\\)\n\n\nis the incomplete beta function and \\\\(B(a, b)\\\\) is the *complete*\nbeta function."
}
op {
name: "BiasAdd"
@ -3633,6 +3633,38 @@ op {
summary: "Return the reduction indices for computing gradients of s0 op s1 with broadcast."
description: "This is typically used by gradient computations for a broadcasting operation."
}
op {
name: "Bucketize"
input_arg {
name: "input"
description: "Any shape of Tensor contains with int or float type."
type_attr: "T"
}
output_arg {
name: "output"
description: "Same shape with \'input\', each value of input replaced with bucket index.\n\n@compatibility(numpy)\nEquivalent to np.digitize.\n@end_compatibility"
type: DT_INT32
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
attr {
name: "boundaries"
type: "list(float)"
description: "A sorted list of floats gives the boundary of the buckets."
}
summary: "Bucketizes \'input\' based on \'boundaries\'."
description: "For example, if the inputs are\n boundaries = [0, 10, 100]\n input = [[-5, 10000]\n [150, 10]\n [5, 100]]\n\nthen the output will be\n output = [[0, 3]\n [3, 2]\n [1, 3]]"
}
op {
name: "CTCBeamSearchDecoder"
input_arg {
@ -8713,6 +8745,51 @@ op {
description: "This op creates a hash table, specifying the type of its keys and values.\nBefore using the table you will have to initialize it. After initialization the\ntable will be immutable."
is_stateful: true
}
op {
name: "HashTableV2"
output_arg {
name: "table_handle"
description: "Handle to a table."
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
description: "If non-empty, this table is placed in the given container.\nOtherwise, a default container is used."
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
description: "If non-empty, this table is shared under the given name across\nmultiple sessions."
}
attr {
name: "use_node_name_sharing"
type: "bool"
default_value {
b: false
}
description: "If true and shared_name is empty, the table is shared\nusing the node name."
}
attr {
name: "key_dtype"
type: "type"
description: "Type of the table keys."
}
attr {
name: "value_dtype"
type: "type"
description: "Type of the table values."
}
summary: "Creates a non-initialized hash table."
description: "This op creates a hash table, specifying the type of its keys and values.\nBefore using the table you will have to initialize it. After initialization the\ntable will be immutable."
is_stateful: true
}
op {
name: "HistogramSummary"
input_arg {
@ -8954,7 +9031,7 @@ op {
}
}
summary: "Compute the lower regularized incomplete Gamma function `Q(a, x)`."
description: "The lower regularized incomplete Gamma function is defined as:\n\n```\nP(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\n```\nwhere\n```\ngamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\n```\nis the lower incomplete Gamma function.\n\nNote, above `Q(a, x)` (`Igammac`) is the upper regularized complete\nGamma function."
description: "The lower regularized incomplete Gamma function is defined as:\n\n\n\\\\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\\\)\n\nwhere\n\n\\\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\\\)\n\nis the lower incomplete Gamma function.\n\nNote, above `Q(a, x)` (`Igammac`) is the upper regularized complete\nGamma function."
}
op {
name: "Igammac"
@ -8981,7 +9058,7 @@ op {
}
}
summary: "Compute the upper regularized incomplete Gamma function `Q(a, x)`."
description: "The upper regularized incomplete Gamma function is defined as:\n\n```\nQ(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\n```\nwhere\n```\nGamma(a, x) = int_{x}^{\\infty} t^{a-1} exp(-t) dt\n```\nis the upper incomplete Gama function.\n\nNote, above `P(a, x)` (`Igamma`) is the lower regularized complete\nGamma function."
description: "The upper regularized incomplete Gamma function is defined as:\n\n\\\\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\\\)\n\nwhere\n\n\\\\(Gamma(a, x) = int_{x}^{\\infty} t^{a-1} exp(-t) dt\\\\)\n\nis the upper incomplete Gama function.\n\nNote, above `P(a, x)` (`Igamma`) is the lower regularized complete\nGamma function."
}
op {
name: "Imag"
@ -9223,6 +9300,82 @@ op {
summary: "Initializes a table from a text file."
description: "It inserts one key-value pair into the table for each line of the file.\nThe key and value is extracted from the whole line content, elements from the\nsplit line based on `delimiter` or the line number (starting from zero).\nWhere to extract the key and value from a line is specified by `key_index` and\n`value_index`.\n\n- A value of -1 means use the line number(starting from zero), expects `int64`.\n- A value of -2 means use the whole line content, expects `string`.\n- A value >= 0 means use the index (starting at zero) of the split line based\n on `delimiter`."
}
op {
name: "InitializeTableFromTextFileV2"
input_arg {
name: "table_handle"
description: "Handle to a table which will be initialized."
type: DT_RESOURCE
}
input_arg {
name: "filename"
description: "Filename of a vocabulary text file."
type: DT_STRING
}
attr {
name: "key_index"
type: "int"
description: "Column index in a line to get the table `key` values from."
has_minimum: true
minimum: -2
}
attr {
name: "value_index"
type: "int"
description: "Column index that represents information of a line to get the table\n`value` values from."
has_minimum: true
minimum: -2
}
attr {
name: "vocab_size"
type: "int"
default_value {
i: -1
}
description: "Number of elements of the file, use -1 if unknown."
has_minimum: true
minimum: -1
}
attr {
name: "delimiter"
type: "string"
default_value {
s: "\t"
}
description: "Delimiter to separate fields in a line."
}
summary: "Initializes a table from a text file."
description: "It inserts one key-value pair into the table for each line of the file.\nThe key and value is extracted from the whole line content, elements from the\nsplit line based on `delimiter` or the line number (starting from zero).\nWhere to extract the key and value from a line is specified by `key_index` and\n`value_index`.\n\n- A value of -1 means use the line number(starting from zero), expects `int64`.\n- A value of -2 means use the whole line content, expects `string`.\n- A value >= 0 means use the index (starting at zero) of the split line based\n on `delimiter`."
is_stateful: true
}
op {
name: "InitializeTableV2"
input_arg {
name: "table_handle"
description: "Handle to a table which will be initialized."
type: DT_RESOURCE
}
input_arg {
name: "keys"
description: "Keys of type Tkey."
type_attr: "Tkey"
}
input_arg {
name: "values"
description: "Values of type Tval."
type_attr: "Tval"
}
attr {
name: "Tkey"
type: "type"
}
attr {
name: "Tval"
type: "type"
}
summary: "Table initializer that takes two tensors for keys and values respectively."
is_stateful: true
}
op {
name: "Inv"
input_arg {
@ -10035,6 +10188,34 @@ op {
}
summary: "Outputs all keys and values in the table."
}
op {
name: "LookupTableExportV2"
input_arg {
name: "table_handle"
description: "Handle to the table."
type: DT_RESOURCE
}
output_arg {
name: "keys"
description: "Vector of all keys present in the table."
type_attr: "Tkeys"
}
output_arg {
name: "values"
description: "Tensor of all values in the table. Indexed in parallel with `keys`."
type_attr: "Tvalues"
}
attr {
name: "Tkeys"
type: "type"
}
attr {
name: "Tvalues"
type: "type"
}
summary: "Outputs all keys and values in the table."
is_stateful: true
}
op {
name: "LookupTableFind"
input_arg {
@ -10068,6 +10249,39 @@ op {
summary: "Looks up keys in a table, outputs the corresponding values."
description: "The tensor `keys` must of the same type as the keys of the table.\nThe output `values` is of the type of the table values.\n\nThe scalar `default_value` is the value output for keys not present in the\ntable. It must also be of the same type as the table values."
}
op {
name: "LookupTableFindV2"
input_arg {
name: "table_handle"
description: "Handle to the table."
type: DT_RESOURCE
}
input_arg {
name: "keys"
description: "Any shape. Keys to look up."
type_attr: "Tin"
}
input_arg {
name: "default_value"
type_attr: "Tout"
}
output_arg {
name: "values"
description: "Same shape as `keys`. Values found in the table, or `default_values`\nfor missing keys."
type_attr: "Tout"
}
attr {
name: "Tin"
type: "type"
}
attr {
name: "Tout"
type: "type"
}
summary: "Looks up keys in a table, outputs the corresponding values."
description: "The tensor `keys` must of the same type as the keys of the table.\nThe output `values` is of the type of the table values.\n\nThe scalar `default_value` is the value output for keys not present in the\ntable. It must also be of the same type as the table values."
is_stateful: true
}
op {
name: "LookupTableImport"
input_arg {
@ -10097,6 +10311,35 @@ op {
summary: "Replaces the contents of the table with the specified keys and values."
description: "The tensor `keys` must be of the same type as the keys of the table.\nThe tensor `values` must be of the type of the table values."
}
op {
name: "LookupTableImportV2"
input_arg {
name: "table_handle"
description: "Handle to the table."
type: DT_RESOURCE
}
input_arg {
name: "keys"
description: "Any shape. Keys to look up."
type_attr: "Tin"
}
input_arg {
name: "values"
description: "Values to associate with keys."
type_attr: "Tout"
}
attr {
name: "Tin"
type: "type"
}
attr {
name: "Tout"
type: "type"
}
summary: "Replaces the contents of the table with the specified keys and values."
description: "The tensor `keys` must be of the same type as the keys of the table.\nThe tensor `values` must be of the type of the table values."
is_stateful: true
}
op {
name: "LookupTableInsert"
input_arg {
@ -10126,6 +10369,35 @@ op {
summary: "Updates the table to associates keys with values."
description: "The tensor `keys` must be of the same type as the keys of the table.\nThe tensor `values` must be of the type of the table values."
}
op {
name: "LookupTableInsertV2"
input_arg {
name: "table_handle"
description: "Handle to the table."
type: DT_RESOURCE
}
input_arg {
name: "keys"
description: "Any shape. Keys to look up."
type_attr: "Tin"
}
input_arg {
name: "values"
description: "Values to associate with keys."
type_attr: "Tout"
}
attr {
name: "Tin"
type: "type"
}
attr {
name: "Tout"
type: "type"
}
summary: "Updates the table to associates keys with values."
description: "The tensor `keys` must be of the same type as the keys of the table.\nThe tensor `values` must be of the type of the table values."
is_stateful: true
}
op {
name: "LookupTableSize"
input_arg {
@ -10141,6 +10413,21 @@ op {
}
summary: "Computes the number of elements in the given table."
}
op {
name: "LookupTableSizeV2"
input_arg {
name: "table_handle"
description: "Handle to the table."
type: DT_RESOURCE
}
output_arg {
name: "size"
description: "Scalar that contains number of elements in the table."
type: DT_INT64
}
summary: "Computes the number of elements in the given table."
is_stateful: true
}
op {
name: "LoopCond"
input_arg {
@ -11845,6 +12132,80 @@ op {
description: "\"open addressing\" with quadratic reprobing to resolve collisions.\n\nThis op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a scalar. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation."
is_stateful: true
}
op {
name: "MutableDenseHashTableV2"
input_arg {
name: "empty_key"
description: "The key used to represent empty key buckets internally. Must not\nbe used in insert or lookup operations."
type_attr: "key_dtype"
}
output_arg {
name: "table_handle"
description: "Handle to a table."
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
description: "If non-empty, this table is placed in the given container.\nOtherwise, a default container is used."
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
description: "If non-empty, this table is shared under the given name across\nmultiple sessions."
}
attr {
name: "use_node_name_sharing"
type: "bool"
default_value {
b: false
}
}
attr {
name: "key_dtype"
type: "type"
description: "Type of the table keys."
}
attr {
name: "value_dtype"
type: "type"
description: "Type of the table values."
}
attr {
name: "value_shape"
type: "shape"
default_value {
shape {
}
}
description: "The shape of each value."
}
attr {
name: "initial_num_buckets"
type: "int"
default_value {
i: 131072
}
description: "The initial number of hash table buckets. Must be a power\nto 2."
}
attr {
name: "max_load_factor"
type: "float"
default_value {
f: 0.8
}
description: "The maximum ratio between number of entries and number of\nbuckets before growing the table. Must be between 0 and 1."
}
summary: "Creates an empty hash table that uses tensors as the backing store. It uses"
description: "\"open addressing\" with quadratic reprobing to resolve collisions.\n\nThis op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a scalar. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation."
is_stateful: true
}
op {
name: "MutableHashTable"
output_arg {
@ -11944,6 +12305,103 @@ op {
description: "This op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a vector. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation."
is_stateful: true
}
op {
name: "MutableHashTableOfTensorsV2"
output_arg {
name: "table_handle"
description: "Handle to a table."
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
description: "If non-empty, this table is placed in the given container.\nOtherwise, a default container is used."
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
description: "If non-empty, this table is shared under the given name across\nmultiple sessions."
}
attr {
name: "use_node_name_sharing"
type: "bool"
default_value {
b: false
}
}
attr {
name: "key_dtype"
type: "type"
description: "Type of the table keys."
}
attr {
name: "value_dtype"
type: "type"
description: "Type of the table values."
}
attr {
name: "value_shape"
type: "shape"
default_value {
shape {
}
}
}
summary: "Creates an empty hash table."
description: "This op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a vector. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation."
is_stateful: true
}
op {
name: "MutableHashTableV2"
output_arg {
name: "table_handle"
description: "Handle to a table."
type: DT_RESOURCE
}
attr {
name: "container"
type: "string"
default_value {
s: ""
}
description: "If non-empty, this table is placed in the given container.\nOtherwise, a default container is used."
}
attr {
name: "shared_name"
type: "string"
default_value {
s: ""
}
description: "If non-empty, this table is shared under the given name across\nmultiple sessions."
}
attr {
name: "use_node_name_sharing"
type: "bool"
default_value {
b: false
}
description: "If true and shared_name is empty, the table is shared\nusing the node name."
}
attr {
name: "key_dtype"
type: "type"
description: "Type of the table keys."
}
attr {
name: "value_dtype"
type: "type"
description: "Type of the table values."
}
summary: "Creates an empty hash table."
description: "This op creates a mutable hash table, specifying the type of its keys and\nvalues. Each value must be a scalar. Data can be inserted into the table using\nthe insert operations. It does not support the initialization operation."
is_stateful: true
}
op {
name: "Neg"
input_arg {
@ -12884,7 +13342,7 @@ op {
}
}
summary: "Compute the polygamma function \\\\(\\psi^{(n)}(x)\\\\)."
description: "The polygamma function is defined as:\n\n```\n\\psi^{(n)}(x) = \\frac{d^n}{dx^n} \\psi(x)\n```\nwhere \\\\(\\psi(x)\\\\) is the digamma function."
description: "The polygamma function is defined as:\n\n\n\\\\(\\psi^{(n)}(x) = \\frac{d^n}{dx^n} \\psi(x)\\\\)\n\nwhere \\\\(\\psi(x)\\\\) is the digamma function."
}
op {
name: "Pow"
@ -19650,7 +20108,7 @@ op {
minimum: 1
}
summary: "Distributed version of Stochastic Dual Coordinate Ascent (SDCA) optimizer for"
description: "linear models with L1 + L2 regularization. As global optimization objective is\nstrongly-convex, the optimizer optimizes the dual objective at each step. The\noptimizer applies each update one example at a time. Examples are sampled\nuniformly, and the optimizer is learning rate free and enjoys linear convergence\nrate.\n\nProximal Stochastic Dual Coordinate Ascent, Shalev-Shwartz, Shai; Zhang, Tong.\n2012 arXiv1211.2717S: http://arxiv.org/pdf/1211.2717v1.pdf\n\n Loss objective = \\sum f_{i}(wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|\n\nAdding vs. Averaging in Distributed Primal-Dual Optimization.\nChenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, Peter Richtarik,\nMartin Takac http://arxiv.org/abs/1502.03508\n\nStochastic Dual Coordinate Ascent with Adaptive Probabilities\nDominik Csiba, Zheng Qu, Peter Richtarik https://arxiv.org/abs/1502.08053"
description: "linear models with L1 + L2 regularization. As global optimization objective is\nstrongly-convex, the optimizer optimizes the dual objective at each step. The\noptimizer applies each update one example at a time. Examples are sampled\nuniformly, and the optimizer is learning rate free and enjoys linear convergence\nrate.\n\n[Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).<br>\nShai Shalev-Shwartz, Tong Zhang. 2012\n\n$$Loss Objective = \\sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$\n\n[Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).<br>\nChenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan,\nPeter Richtarik, Martin Takac. 2015\n\n[Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).<br>\nDominik Csiba, Zheng Qu, Peter Richtarik. 2015"
}
op {
name: "SdcaShrinkL1"
@ -26527,5 +26985,5 @@ op {
}
}
summary: "Compute the Hurwitz zeta function \\\\(\\zeta(x, q)\\\\)."
description: "The Hurwitz zeta function is defined as:\n\n```\n\\zeta(x, q) = \\sum_{n=0}^{\\infty} (q + n)^{-x}\n```"
description: "The Hurwitz zeta function is defined as:\n\n\n\\\\(\\zeta(x, q) = \\sum_{n=0}^{\\infty} (q + n)^{-x}\\\\)"
}

View File

@ -72,17 +72,17 @@ optimizer applies each update one example at a time. Examples are sampled
uniformly, and the optimizer is learning rate free and enjoys linear convergence
rate.
Proximal Stochastic Dual Coordinate Ascent, Shalev-Shwartz, Shai; Zhang, Tong.
2012 arXiv1211.2717S: http://arxiv.org/pdf/1211.2717v1.pdf
[Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).<br>
Shai Shalev-Shwartz, Tong Zhang. 2012
Loss objective = \sum f_{i}(wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|
$$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$
Adding vs. Averaging in Distributed Primal-Dual Optimization.
Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, Peter Richtarik,
Martin Takac http://arxiv.org/abs/1502.03508
[Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).<br>
Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan,
Peter Richtarik, Martin Takac. 2015
Stochastic Dual Coordinate Ascent with Adaptive Probabilities
Dominik Csiba, Zheng Qu, Peter Richtarik https://arxiv.org/abs/1502.08053
[Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).<br>
Dominik Csiba, Zheng Qu, Peter Richtarik. 2015
loss_type: Type of the primal loss. Currently SdcaSolver supports logistic,
squared and hinge losses.
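As a concrete reading of the loss objective above, the sketch below evaluates the primal objective with NumPy for the logistic case; the choice of loss and the toy data are illustrative, since the op supports logistic, squared and hinge losses.
```
import numpy as np

def primal_objective(w, X, y, l1, l2):
    # \sum_i f_i(w . x_i) + (l2 / 2) * |w|^2 + l1 * |w|, where f_i is the
    # logistic loss for labels y_i in {-1, +1}.
    margins = y * X.dot(w)
    loss = np.sum(np.log1p(np.exp(-margins)))
    return loss + 0.5 * l2 * np.dot(w, w) + l1 * np.sum(np.abs(w))

X = np.array([[1.0, 0.0], [0.0, 1.0]])
y = np.array([1.0, -1.0])
print(primal_objective(np.zeros(2), X, y, l1=0.1, l2=1.0))  # 2*log(2) ~= 1.386
```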

View File

@ -0,0 +1,429 @@
# TensorFlow Performance Benchmarks
## Overview
A selection of image classification models were tested across multiple platforms
to create a point of reference for the TensorFlow community. The methodology,
links to the scripts, and commands to reproduce the results are in the
[appendix](#appendix).
## Results for image classification models
InceptionV3 ([arXiv:1512.00567](https://arxiv.org/abs/1512.00567)),
ResNet-50 ([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)),
ResNet-152 ([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)), VGG16
([arXiv:1409.1556](https://arxiv.org/abs/1409.1556)), and
[AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)
were tested using the [ImageNet](http://www.image-net.org/) data set. Tests were
run on Google Compute Engine, Amazon Elastic Compute Cloud (Amazon EC2), and an
NVIDIA® DGX-1™. Most of the tests were run with both synthetic and real data.
Testing with synthetic data was done by using a `tf.Variable` set to the same
shape as the data expected by each model for ImageNet. We believe it is
important to include real data measurements when benchmarking a platform. This
load tests both the underlying hardware and the framework at preparing data for
actual training. We start with synthetic data to remove disk I/O as a variable
and to set a baseline. Real data is then used to verify that the TensorFlow
input pipeline and the underlying disk I/O are saturating the compute units.
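As a rough sketch of what "synthetic data" means here, the snippet below builds an ImageNet-shaped batch as TF 1.x variables; the exact shapes, dtype and initializers are illustrative assumptions, not the benchmark code itself.
```
import tensorflow as tf

# Minimal sketch of a synthetic input batch (TF 1.x style API).
batch_size, height, width, channels = 64, 224, 224, 3
synthetic_images = tf.Variable(
    tf.random_normal([batch_size, height, width, channels], stddev=1e-1),
    trainable=False, name='synthetic_images')
synthetic_labels = tf.Variable(
    tf.random_uniform([batch_size], minval=0, maxval=1000, dtype=tf.int32),
    trainable=False, name='synthetic_labels')
# These variables stand in for the real input pipeline, removing disk I/O
# from the measurement while keeping the model graph unchanged.
```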
### Training with NVIDIA® DGX-1™ (NVIDIA® Tesla® P100)
<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/perf_summary_p100_single_server.png">
</div>
Details and additional results are in the [Details for NVIDIA® DGX-1™ (NVIDIA®
Tesla® P100)](#details-for-nvidia®-dgx-1™-nvidia®-tesla®-p100) section.
### Training with NVIDIA® Tesla® K80
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/perf_summary_k80_single_server.png">
</div>
Details and additional results are in the [Details for Google Compute Engine
(NVIDIA® Tesla® K80)](#details-for-google-compute-engine-nvidia®-tesla®-k80) and
[Details for Amazon EC2 (NVIDIA® Tesla®
K80)](#details-for-amazon-ec2-nvidia®-tesla®-k80) sections.
### Distributed training with NVIDIA® Tesla® K80
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/perf_summary_k80_aws_distributed.png">
</div>
Details and additional results are in the [Details for Amazon EC2 Distributed
(NVIDIA® Tesla® K80)](#details-for-amazon-ec2-distributed-nvidia®-tesla®-k80)
section.
### Compare synthetic with real data training
**NVIDIA® Tesla® P100**
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:35%" src="../images/perf_summary_p100_data_compare_inceptionv3.png">
<img style="width:35%" src="../images/perf_summary_p100_data_compare_resnet50.png">
</div>
**NVIDIA® Tesla® K80**
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:35%" src="../images/perf_summary_k80_data_compare_inceptionv3.png">
<img style="width:35%" src="../images/perf_summary_k80_data_compare_resnet50.png">
</div>
## Details for NVIDIA® DGX-1™ (NVIDIA® Tesla® P100)
### Environment
* **Instance type**: NVIDIA® DGX-1™
* **GPU:** 8x NVIDIA® Tesla® P100
* **OS:** Ubuntu 16.04 LTS with tests run via Docker
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package`
* **Disk:** local SSD
* **DataSet:** ImageNet
Batch size and optimizer used for each model.
| InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
------------------ | ----------- | --------- | ---------- | ------- | -----
Batch size per GPU | 64 | 64 | 64 | 512 | 64
Optimizer | sgd | sgd | sgd | sgd | sgd
Configuration used for each model.
Model | variable_update | local_parameter_device
----------- | ---------------------- | ----------------------
InceptionV3 | parameter_server | cpu
ResNet50 | parameter_server | cpu
ResNet152 | parameter_server | cpu
AlexNet | replicated (with NCCL) | n/a
VGG16 | replicated (with NCCL) | n/a
### Results
Batch size and optimizer used for each model are listed in the tables above.
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/perf_summary_p100_single_server.png">
</div>
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:35%" src="../images/perf_dgx1_synth_p100_single_server_scaling.png">
<img style="width:35%" src="../images/perf_dgx1_real_p100_single_server_scaling.png">
</div>
**Training synthetic data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 142 | 238 | 95.6 | 2987 | 132
2 | 284 | 479 | 187 | 5658 | 259
4 | 569 | 948 | 374 | 10509 | 511
8 | 1131 | 1886 | 744 | 17822 | 959
**Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 142 | 239 | 95.5 | 2890 | 132
2 | 278 | 468 | 187 | 4448 | 245
4 | 551 | 938 | 373 | 7105 | 466
8 | 1079 | 1802 | 721 | N/A | 794
Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above because it saturated the input pipeline.
## Details for Google Compute Engine (NVIDIA® Tesla® K80)
### Environment
* **Instance type**: n1-standard-32-k80x8
* **GPU:** 8x NVIDIA® Tesla® K80
* **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1.7 TB Shared SSD persistent disk (800 MB/s)
* **DataSet:** ImageNet
* **Test Date:** April 2017
Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
tested with a batch size of 32. Those results are in the *other results*
section.
| InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
------------------ | ----------- | --------- | ---------- | ------- | -----
Batch size per GPU | 64 | 64 | 32 | 512 | 32
Optimizer | sgd | sgd | sgd | sgd | sgd
The configuration used for each model was `variable_update` equal to
`parameter_server` and `local_parameter_device` equal to `cpu`.
### Results
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:35%" src="../images/perf_gce_synth_k80_single_server_scaling.png">
<img style="width:35%" src="../images/perf_gce_real_k80_single_server_scaling.png">
</div>
**Training synthetic data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.5 | 56.8 | 20.8 | 656 | 30.3
2 | 57.8 | 107 | 39.1 | 1210 | 56.2
4 | 116 | 212 | 77.2 | 2330 | 106
8 | 227 | 419 | 151 | 4640 | 222
**Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.5 | 56.7 | 20.7 | 639 | 30.2
2 | 57.8 | 107 | 39 | 1136 | 55.5
4 | 115 | 211 | 77.3 | 2067 | 106
8 | 225 | 418 | 150 | 4056 | 213
### Other Results
**Training synthetic data**
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
1 | 29.3 | 53.9
2 | 55.0 | 101
4 | 109 | 200
8 | 216 | 398
**Training real data**
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
1 | 29.3 | 53.6
2 | 55 | 102
4 | 109 | 200
8 | 215 | 387
## Details for Amazon EC2 (NVIDIA® Tesla® K80)
### Environment
* **Instance type**: p2.8xlarge
* **GPU:** 8x NVIDIA® Tesla® K80
* **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1TB Amazon EFS (burst 100 MiB/sec for 12 hours, continuous 50
MiB/sec)
* **DataSet:** ImageNet
* **Test Date:** April 2017
Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
tested with a batch size of 32. Those results are in the *other results*
section.
| InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
------------------ | ----------- | --------- | ---------- | ------- | -----
Batch size per GPU | 64 | 64 | 32 | 512 | 32
Optimizer | sgd | sgd | sgd | sgd | sgd
Configuration used for each model.
Model | variable_update | local_parameter_device
----------- | ------------------------- | ----------------------
InceptionV3 | parameter_server | cpu
ResNet-50 | replicated (without NCCL) | gpu
ResNet-152 | replicated (without NCCL) | gpu
AlexNet | parameter_server | gpu
VGG16 | parameter_server | gpu
### Results
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:35%" src="../images/perf_aws_synth_k80_single_server_scaling.png">
<img style="width:35%" src="../images/perf_aws_real_k80_single_server_scaling.png">
</div>
**Training synthetic data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.8 | 56.3 | 20.9 | 684 | 32.4
2 | 58.7 | 108 | 39.3 | 1244 | 61.5
4 | 117 | 217 | 79.1 | 2479 | 123
8 | 230 | 419 | 156 | 4853 | 234
**Training real data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
1 | 30.5 | 56.0 | 20.6 | 674 | 32.0
2 | 58.7 | 107 | 39.0 | 1227 | 61.0
4 | 118 | 205 | 77.9 | 2201 | 120
8 | 228 | 405 | 152 | N/A | 191
Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to our EFS setup not providing enough throughput.
### Other Results
**Training synthetic data**
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
1 | 29.9 | 53.5
2 | 57.5 | 101
4 | 114 | 202
8 | 216 | 380
**Training real data**
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
1 | 30.0 | 53.6
2 | 57.5 | 101
4 | 113 | 202
8 | 212 | 379
## Details for Amazon EC2 Distributed (NVIDIA® Tesla® K80)
### Environment
* **Instance type**: p2.8xlarge
* **GPU:** 8x NVIDIA® Tesla® K80
* **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
//tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1.0 TB EFS (burst 100 MB/sec for 12 hours, continuous 50 MB/sec)
* **DataSet:** ImageNet
* **Test Date:** April 2017
The batch size and optimizer used for the tests are listed in the table. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were
tested with a batch size of 32. Those results are in the *other results*
section.
| InceptionV3 | ResNet-50 | ResNet-152
------------------ | ----------- | --------- | ----------
Batch size per GPU | 64 | 64 | 32
Optimizer | sgd | sgd | sgd
Configuration used for each model.
Model | variable_update | local_parameter_device
----------- | ---------------------- | ----------------------
InceptionV3 | distributed_replicated | n/a
ResNet-50 | distributed_replicated | n/a
ResNet-152 | distributed_replicated | n/a
To simplify server setup, EC2 instances (p2.8xlarge) running worker servers also
ran parameter servers. Equal numbers of parameter servers and worker servers were
used with the following exceptions:
* InceptionV3: 8 instances / 6 parameter servers
* ResNet-50: (batch size 32) 8 instances / 4 parameter servers
* ResNet-152: 8 instances / 4 parameter servers
### Results
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:95%" src="../images/perf_summary_k80_aws_distributed.png">
</div>
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:70%" src="../images/perf_aws_synth_k80_distributed_scaling.png">
</div>
**Training synthetic data**
GPUs | InceptionV3 | ResNet-50 | ResNet-152
---- | ----------- | --------- | ----------
1 | 29.7 | 55.0 | 19.8
8 | 229 | 410 | 150
16 | 459 | 825 | 300
32 | 902 | 1468 | 575
64 | 1783 | 3051 | 1004
### Other Results
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:50%" src="../images/perf_aws_synth_k80_multi_server_batch32.png">
</div>
**Training synthetic data**
GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
1 | 29.2 | 53.0
8 | 219 | 363
16 | 427 | 719
32 | 820 | 1265
64 | 1608 | 2623
## Appendix
### Executing benchmark tests
The code for the benchmarks was created to be both used for benchmarking
TensorFlow as well as used as a tool to test hardware platforms. The benchmark
code includes modes such as `trivial`, which runs a virtually empty model that is
useful for testing the maximum possible samples/sec of the input pipeline, among
other things. This exercises not only TensorFlow but also the throughput of the
underlying systems. There are two ways to execute the benchmarks in
[tf_cnn_benchmarks.py](TODO: LINK TO GITHUB):
1. Execute [tf_cnn_benchmarks.py](TODO: LINK TO GITHUB) directly
2. Utilize the [small wrapper](TODO: LINK TO GITHUB) that helps pick the
correct config
The wrapper is suggested as a starting point; from there, investigate the variety
of options available in `tf_cnn_benchmarks.py`. While the wrapper provides
extensive examples, below are a couple of highlights.
Run ResNet-50 on a single instance with 8 GPUs. The `system` argument is used to
determine the optimal configuration. The supported values are `gce`, `aws`, and
`dgx1`. If `system` is not passed, the best config for the most widely available
hardware is used.
```bash
python main.py --model=resnet50 --num_gpus=8
python main.py --system=aws --model=resnet50 --num_gpus=8
```
Run ResNet-50 on 2 hosts, e.g. host_0 (10.0.0.1) and host_1 (10.0.0.2), with 8
GPUs each on aws.
```bash
# Run the following commands on host_0 (10.0.0.1):
$ python main.py --system=aws --model=resnet50 --job_name=worker \
    --hosts=10.0.0.1,10.0.0.2 --task_index=0
$ python main.py --system=aws --model=resnet50 --job_name=ps \
    --hosts=10.0.0.1,10.0.0.2 --task_index=0
# Run the following commands on host_1 (10.0.0.2):
$ python main.py --system=aws --model=resnet50 --job_name=worker \
    --hosts=10.0.0.1,10.0.0.2 --task_index=1
$ python main.py --system=aws --model=resnet50 --job_name=ps \
    --hosts=10.0.0.1,10.0.0.2 --task_index=1
```
### Methodology
Unless otherwise stated, each test is run 5 times and then the times are
averaged together. GPUs are run in their default state on the given platform.
For NVIDIA® Tesla® K80 this means leaving on [GPU
Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/)
unless it has been turned off by the provider. For a given test, 10 warmup steps
are done and then the next 100 steps are averaged.
View File
@ -0,0 +1,389 @@
# High-Performance Models
TensorFlow is a powerful and flexible machine learning platform.
It can be used to distribute model training and inference across a large number
of machines and computation devices.
Its software stack is made of a few layers:
* a fast and powerful C++ core
* low-level Python primitives that sit right above individual kernels
* a diverse range of high-level libraries that aim to make building real models
easier
There are many existing examples and tutorials that explain useful features in
TensorFlow. The goal of this set of scripts is to demonstrate that we can build
flexible and powerful high-performance models using the low-level APIs.
In the future, many of the high-performance primitives will be incorporated into
high-level APIs, and made available to more users transparently.
But meanwhile, we show that it is fairly easy for advanced users to build highly
scalable models targeting different system types, network topologies, etc.
We divide our effort to build high-performance models into three categories:
1. A fast input pipeline to read data from disk, preprocess it, and make it
ready on the GPU.
2. A high-throughput model that trains on GPU very efficiently.
3. Fast variable and gradients distribution mechanisms that scale well across
many machines and computation devices.
## Input Pipeline
The input pipeline is the part of a TensorFlow program that reads input data,
shuffles it, and preprocesses it.
Among the most important features to build a fast input pipeline:
* Avoid using feed-dictionary to feed a large amount of data for each step.
* Instead, use reader ops to get data into TensorFlow directly.
* Parallelize data processing.
* Use software pipelining to feed data, so that data is available immediately
when needed.
One way to implement software pipelining in TensorFlow is through
`tf.FifoQueue`, and it is possible to parallelize data processing through
`tf.train.queue_runner`, which uses Python threads as its underlying
implementation.
This lays the foundation for the current Inception input pipeline.
This design is well suited to feeding older generations of GPUs,
but the overhead of Python threads is too large to feed newer GPUs that are four
to five times faster.
In this model, we explore an alternative design that uses the native
parallelism in TensorFlow. In our example of an image model input pipeline,
there are a few important parts:
* Choose and read the image files from the disk.
* Decode the image data into images, transform and add distortion so they are
ready to be used.
* Organize the transformed images into a minibatch.
* Transfer the images from CPU to GPU, so they are ready for model training.
It is important to note that the dominant part of each stage can happen in
parallel with that of other stages:
the file IO uses DMA to transfer the data from hard disk to memory;
image decoding, transformation and distortion are CPU-heavy;
the data transfer from CPU to GPU uses the GPU's copy-engine unit;
and the GPU kernels use the main SMs of the GPU.
It is natural to cut our pipeline into those parts so they can run in parallel
with each other.
Also, as mentioned earlier, most of the current input pipeline heavily uses
Python threads. However, the large overhead introduced by Python threads
severely limits its scalability when the newer GPUs are a lot faster; we can
alleviate this by making a single `session.run` call execute all parts of the
pipeline.
### Parallelize IO Reads
In this new model, we use the native parallelism in TensorFlow: TensorFlow
subscribes to an eager-execution model, which means that when nodes in the graph
become available, TensorFlow will try to execute as many of them as possible.
In order to parallelize reading from hard disk, we use `data_flow_ops.RecordInput`
in this model.
Given a list of input files of TFRecords, `RecordInput` continuously reads
records using background threads, placing the records into its own large,
internal pool of records.
When it has loaded at least half of its capacity, it produces output tensors.
Since this op has its own internal threads and is dominated by IO time that
doesn't consume much CPU time, it naturally runs in parallel with the rest of the
model.
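For illustration, here is a minimal sketch (not the benchmark code itself; the
file pattern and sizes are placeholder values) of wiring up
`data_flow_ops.RecordInput` so reading happens on background threads:
```python
import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

# Hypothetical TFRecord shards; RecordInput keeps reading them on background
# threads into an internal pool of records.
record_input = data_flow_ops.RecordInput(
    file_pattern='/tmp/imagenet/train-*',
    parallelism=64,       # number of background reader threads
    buffer_size=10000,    # size of the internal record pool
    batch_size=256)       # records returned per call to get_yield_op()

records = record_input.get_yield_op()            # [256] serialized examples
# Split into individual records so each one can be decoded in parallel.
records = tf.split(records, num_or_size_splits=256, axis=0)
records = [tf.reshape(record, []) for record in records]
```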
### Parallelize Image Processing
After reading from `RecordInput`, the tensors are passed to the input processing
pipeline. For example, if we need to feed 8 GPUs, each with a batch-size of 32,
then for each step we do the following.
First, read 32x8=256 records, and process them individually, in
parallel. This starts with 256 independent RecordInput read ops in the graph.
Then, follow each read with an identical set of ops for processing. Each set is
considered independent and will execute in parallel. The operations include
image decoding, image distortion, and resizing.
Finally, once the images are ready, they will be concatenated together into 8
tensors, each with a batch-size of 32.
Note that we can use `tf.concat` for this purpose.
However, `tf.concat` is implemented as a single op, which waits for all
the inputs to be ready, and then concatenates them together. Since all
inputs are produced in parallel, there will be a long tail waiting for all
inputs to be available; and when concatenation happens, the op becomes memory
limited as all input tensors compete for memory bandwidth.
So for the final concatenation, we use `tf.parallel_stack` instead. This
allocates an uninitialized tensor as an output, and each input tensor is written
to its designated portion of the output tensor as soon as the input is
available. When all the input tensors are finished, the output tensor is passed
along in the graph. This effectively hides all the memory latency with the long
tail of producing all the input tensors.
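As a minimal sketch (with made-up shapes, not the benchmark code itself), the
difference is just which op assembles the batch:
```python
import tensorflow as tf

# Hypothetical: 32 preprocessed images destined for one GPU.
images = [tf.random_normal([224, 224, 3]) for _ in range(32)]

# tf.concat waits until every input is ready before writing the output.
batch_via_concat = tf.concat([tf.expand_dims(im, 0) for im in images], axis=0)

# tf.parallel_stack writes each input into its slot of a preallocated output
# as soon as that input is available, hiding the long tail of producers.
batch_via_parallel_stack = tf.parallel_stack(images)   # shape [32, 224, 224, 3]
```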
### Parallelize CPU-to-GPU Data Transfer
In our example, once all the input images are processed and concatenated
together by the CPU, we have 8 tensors, each of which has a batch-size of 32.
These tensors are then to be used by the GPU for the model training.
In TensorFlow, users can use tensors from one device on any other device
directly. TensorFlow inserts implicit copies to make the tensors available on
any devices where they are used. The runtime schedules the copy between devices
to run before the tensors are actually used. However, if the copy cannot finish
in time, the computation that needs those tensors will stall.
For high-performance models, it is helpful to explicitly schedule the copy ahead
of time in parallel, so that when the computation starts on the GPU, all the
tensors are already available on the right device.
### Software Pipelining
With all the stages capable of being driven by different processors, we insert
`data_flow_ops.StagingArea` in between them so they run in parallel.
`StagingArea` is a queue-like operator similar to `tf.FifoQueue`.
But it offers simpler functionalities and can be executed on both CPU and GPU.
Before the model starts running all the stages, we warm up the stages in order
so the staging buffers in between all have one set of data in them.
During each run step that follows, we will run all the stages.
They read one set of data from the staging buffers at the beginning of each
stage, and push one set at the end.
For example, suppose there are three stages, A, B, and C, with two staging areas
in between them, S1 and S2.
During the warm up, we run:
```
Warm up:
Step 1: A0
Step 2: A1 B0
Actual execution:
Step 3: A2 B1 C0
Step 4: A3 B2 C1
Step 5: A4 B3 C2
```
After the warm up, S1 and S2 each have one set of data in them.
For each step of the actual execution, one set of data is consumed from each
staging area, and one set is added to each.
There are a few nice properties about the scheme:
* All the stages are non-blocking, since the staging areas always have one set
of data after the warm up.
* Each stage can run in parallel since they can all start immediately.
* The staging buffers have a fixed memory overhead. They will have at most one
extra set of data.
* Only a single `session.run()` call is needed to run all stages of the step,
which makes profiling and debugging much easier.
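A minimal sketch of the scheme, assuming a GPU is available and using placeholder
tensors in place of the real input pipeline and model (the shapes and the loop
count are hypothetical):
```python
import tensorflow as tf
from tensorflow.python.ops import data_flow_ops

# Stand-ins for one set of preprocessed data produced on the CPU.
images = tf.random_normal([32, 224, 224, 3])
labels = tf.random_uniform([32], maxval=1000, dtype=tf.int32)

with tf.device('/gpu:0'):
  # S1: the staging buffer between preprocessing and the GPU compute stage.
  stage = data_flow_ops.StagingArea(
      dtypes=[tf.float32, tf.int32],
      shapes=[[32, 224, 224, 3], [32]])
  stage_put = stage.put([images, labels])
  staged_images, staged_labels = stage.get()
  loss = tf.reduce_mean(staged_images)      # stands in for the real model

config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
  sess.run(stage_put)                       # warm up: leave one set in S1
  for _ in range(5):
    # Each step consumes one set from S1 and pushes the next one in.
    sess.run([loss, stage_put])
```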
## Best Practices in Building High-Performance Models
The computation on GPU can happen immediately since the input data have already
been transferred onto GPU when the step starts.
But it is still important to build a model that runs as fast as possible.
Here are some tips for a high-performance convolutional neural network (CNN)
model:
### Build the model with both NHWC and NCHW
Most TensorFlow operations used by a CNN support both NHWC and NCHW data format.
On GPU, NCHW is faster.
But on CPU, NHWC is sometimes faster.
So it is a good idea to build a model that can work with both data formats.
Our model shows a good way to do that effectively.
For GPU training, we should always use NCHW.
But if the model needs inference on CPU, we could use NHWC; weights obtained
from training with NCHW data format can be used for inference in NHWC data
format.
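As a minimal sketch (hypothetical shapes and a single convolution, not our
model's code), the same weights can drive either layout; only the input transpose
and the `data_format`/`strides` arguments change:
```python
import tensorflow as tf

images_nhwc = tf.random_normal([32, 224, 224, 3])
# Filter layout [height, width, in_channels, out_channels] is the same for
# both data formats, so NCHW-trained weights work for NHWC inference.
filters = tf.get_variable('conv1_weights', [7, 7, 3, 64])

data_format = 'NCHW'   # 'NCHW' for GPU training, 'NHWC' for CPU inference
if data_format == 'NCHW':
  images = tf.transpose(images_nhwc, [0, 3, 1, 2])          # NHWC -> NCHW
  conv = tf.nn.conv2d(images, filters, strides=[1, 1, 2, 2],
                      padding='SAME', data_format='NCHW')
else:
  conv = tf.nn.conv2d(images_nhwc, filters, strides=[1, 2, 2, 1],
                      padding='SAME', data_format='NHWC')
```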
### Use Fused Batch-Normalization
The default batch-normalization in TensorFlow is implemented as composite
operations.
This is very general, but often leads to suboptimal performance.
An alternative is fused batch-normalization, which is often much faster on GPU.
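For example, assuming the `fused` argument of `tf.layers.batch_normalization`
(a minimal sketch with a hypothetical activation tensor):
```python
import tensorflow as tf

x = tf.random_normal([32, 224, 224, 64])        # hypothetical activations
# fused=True requests the single fused batch-norm kernel instead of the
# composite-op implementation.
y = tf.layers.batch_normalization(x, fused=True, training=True)
```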
## Variable Distribution and Gradient Aggregation
During training, training variable values are updated using aggregated gradients
and deltas. In this model, we demonstrate that with the flexible and
general-purpose TensorFlow primitives, it is fairly easy to build a diverse
range of high-performance distribution and aggregation schemes for different
types of systems.
For example:
* The standard parameter-server where each replica of the training model reads
the variables directly, and updates the variable independently. When each
model needs the variables, they are copied over through the standard implicit
copies added by the TensorFlow runtime. We show how to use this method
in local training, distributed synchronous training, and distributed
asynchronous training.
* A replicated mode for local training where each GPU has an identical
copy of the training parameters. The forward and backward computation can
start immediately as the variable data is immediately available. Gradients
are accumulated across all GPUs, and the aggregated total is applied to
each GPU's copy of the variables so that they stay in sync.
* A distributed replicated mode of training where each GPU has an identical copy
of the training parameters, and a master copy of the variables is stored
on the parameter-servers. The forward and backward computation can
start immediately as the variable data is immediately available. Gradients
are accumulated across all GPUs on each server and then the per-server
aggregated gradients are applied to the master copy. After all workers do
this, each worker updates its copy of the variable from the master copy.
We show that most of the variable distribution and aggregation subsystem can
be implemented through TensorFlow low-level primitives with manageable
complexity at the model level. Here we discuss some more details.
### Parameter-server Variables
The most common way trainable variables are managed in TensorFlow models is the
parameter server mode.
In a distributed system, this means that each worker process runs the same
model, and parameter server processes own the master copies of the variables.
When a worker needs a variable from a parameter server, it refers to it
directly. The TensorFlow runtime adds implicit copies to the graph to make the
variable value available on the computation device that needs it. When a
gradient is computed on a worker, it is sent to the parameter server that owns
the particular variable, and the corresponding optimizer is used to update the
variable.
There are some techniques to improve throughput:
* The variables are spread among parameter servers based on their size, for load
balancing.
* When each worker has multiple GPUs, gradients are accumulated across the GPUs
and a single aggregated gradient is sent to the parameter server. This reduces
the network bandwidth and the amount of work done by the parameter servers.
For coordinating between workers, a very common mode is async updates, where
each worker updates the master copy of the variables without synchronizing with
other workers. In our model, we demonstrate that it is fairly easy to introduce
synchronization across workers so updates for all workers are finished in one
step before the next step can start.
The parameter-server method can also be used for local training. In this case,
instead of spreading the master copies of variables across parameter servers,
they are either placed on the CPU or spread across the available GPUs.
Due to the simple nature of this setup, this architecture has gained a lot of
popularity within the community.
This is available in the benchmark scripts as the 'parameter_server'
variable_update mode.
![parameter_server mode in distributed
training](../images/perf_parameter_server_mode_doc.png){
width="900" style="max-width: inherit"}
### Replicated Variables
In this design, each GPU on the server has its own copy of each variable. The
values are kept in sync across GPUs by applying the fully aggregated gradient to
each GPU's copy of the variable.
The variables and data are available at the start of training, so the forward
pass of training can start immediately. Gradients are aggregated across the
devices and the fully aggregated gradient is then applied to each local copy.
Gradient aggregation across the server can be done in different ways:
* Using standard TensorFlow operations to accumulate the total on a single
device (CPU or GPU) and then copy it back to all GPUs.
* Using NVIDIA NCCL, described below in the NCCL section.
This is available in the benchmark scripts for local execution only, as the
'replicated' variable_update mode.
### Replicated Variables in Distributed Training
The replicated method for variables can be extended to distributed training.
One way to do this is like the replicated mode: aggregate the gradients fully
across the cluster and apply them to each local copy of the variable. This may
be shown in a future version of these scripts; for now, the scripts present a
different variation, described here.
In this mode, in addition to each GPU's copy of the variables, a master copy is
stored on the parameter servers. As with the replicated mode, training can start
immediately using the local copies of the variables.
As the gradients of the weights become available, they are sent back to the
parameter servers and all local copies are updated:
1. All the gradients from the GPUs on the same worker are aggregated together.
2. Aggregated gradients from each worker are sent to the parameter server that
owns the variable, where the specified optimizer is used to update the
master copy of the variable.
3. Each worker updates its local copy of the variable from the master. In
the example model, this is done with a cross-replica barrier that waits for
all the workers to finish updating the variables, and fetches the new
variable only after the barrier has been released by all replicas. Once the
copy finishes for all variables, this marks the end of a training step, and a
new step can start.
Although this sounds similar to the standard use of parameter servers, the
performance is often better. This is largely because the computation can happen
without any delay, and much of the copy latency of early gradients can be hidden
by later computation layers.
This is available in the benchmark scripts as the 'distributed_replicated'
variable_update mode.
![distributed_replicated mode](
../images/perf_distributed_replicated_mode_doc.png){
width="900" style="max-width: inherit"}
#### NCCL
In order to broadcast variables and aggregate gradients across different GPUs
within the same host machine, we can use the default TensorFlow implicit copy
mechanism.
However, we can instead use the optional NCCL support. NCCL is an NVIDIA
library that can efficiently broadcast and aggregate data across different GPUs.
It schedules a cooperating kernel on each GPU that knows how to best utilize the
underlying hardware topology; this kernel uses a single SM of the GPU.
In our experiment, we demonstrate that although NCCL often leads to much faster
data aggregation by itself, it doesn't necessarily lead to faster training. Our
hypothesis is that the implicit copies are essentially free since they go to the
copy engine on GPU, as long as its latency can be hidden by the main computation
itself. Although NCCL can transfer data faster, it takes one SM away, and adds
more pressure on the underlying L2 cache. Our results show that for 8 GPUs,
NCCL often leads to better performance. However, for fewer GPUs, the implicit
copies often perform better.
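A minimal sketch of NCCL-based aggregation, assuming `tf.contrib.nccl.all_sum`
and two hypothetical GPUs; the tensors stand in for real gradients:
```python
import tensorflow as tf
from tensorflow.contrib import nccl

# Per-GPU tensors standing in for the gradients of one variable.
per_gpu_grads = []
for i in range(2):
  with tf.device('/gpu:%d' % i):
    per_gpu_grads.append(tf.random_normal([1000, 10]))

# all_sum returns one tensor per input, each living on the same GPU as its
# input and holding the sum across all GPUs.
summed_grads = nccl.all_sum(per_gpu_grads)
```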
#### Staged Variables
We further introduce a staged-variable mode where we use staging areas for both
the variable reads, and their updates.
Similar to software pipelining of the input pipeline, this can hide the data
copy latency.
If the computation time takes longer than the copy and aggregation, the copy
itself becomes essentially free.
The downside is that all the weights read are from the previous training step.
So it is a different algorithm from SGD.
But it is possible to improve its convergence by adjusting learning rate and
other hyperparameters.
## Conclusions
In this high-performance model, we present a number of options to build
high-performance models in TensorFlow.
Due to the flexible design of TensorFlow, advanced features like this often
require no system-level changes, and can be largely achieved through
model-level changes.
We do not claim which combination works best for a particular model.
That should be left to the engineers who build the model and the training system.
Many of the ingredients of the high-performance model will find their way into
high-level primitives that become transparent to users.
However, we have shown that advanced users can easily tune and modify the
underlying model behavior using low-level primitives.
This could be very useful when improving performance for particular system
setups and model configurations.
View File
@ -8,6 +8,7 @@ supervisor.md
debugger.md
tfdbg-tflearn.md
meta_graph.md
saved_model_cli.md
version_semantics.md
data_versions.md
faq.md
View File
@ -0,0 +1,251 @@
# SavedModel CLI (Command-Line Interface)
[`SavedModel`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md)
is a universal serialization format for TensorFlow. It provides a
language-neutral format to save machine-learned models and enables higher-level
systems and tools to produce, consume and transform TensorFlow models.
We provide the SavedModel CLI (command-line interface) as a tool to inspect and
execute a [`MetaGraph`](https://www.tensorflow.org/programmers_guide/meta_graph)
in a SavedModel. For example, you can inspect which
[`SignatureDefs`](https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md),
including their input and output tensors, are in the model without writing any
code. This is useful when you want to quickly check that your input dtype and
shape match the model. Moreover, if you want to test out the model, the CLI lets
you do a sanity check by passing in sample inputs, for example as Python
expressions, and fetching the outputs directly from the command line.
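For context, here is a minimal sketch (a hypothetical toy model and export path,
not one of the models discussed elsewhere) of producing a SavedModel with a
`serving_default` SignatureDef that the `show` and `run` commands below can then
inspect:
```python
import tensorflow as tf

export_dir = '/tmp/saved_model_dir'   # hypothetical export location

x = tf.placeholder(tf.float32, shape=[None, 1], name='x')
w = tf.get_variable('w', initializer=0.5)
y = tf.identity(x * w, name='y')      # stands in for a real model

signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs={'x': tf.saved_model.utils.build_tensor_info(x)},
    outputs={'y': tf.saved_model.utils.build_tensor_info(y)},
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  builder.add_meta_graph_and_variables(
      sess, [tf.saved_model.tag_constants.SERVING],
      signature_def_map={'serving_default': signature})
builder.save()
```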
## Get SavedModel CLI
If TensorFlow is installed on your system through pip, the `saved_model_cli`
binary can be invoked directly from the command line.
To build the binary from source, run the following command:
```
$bazel build tensorflow/python/tools:saved_model_cli
```
## Commands
SavedModel CLI allows users to both show and run computations on a
[`MetaGraphDef`](https://www.tensorflow.org/code/tensorflow/core/protobuf/meta_graph.proto)
in a SavedModel. These are done through the `show` and `run` commands. We will
explain the usage of both commands with detailed examples. SavedModel CLI will
also display this information with the `-h` option.
### `show` command
A SavedModel contains one or more MetaGraphs, identified by their tag-sets. Each
MetaGraph contains both a TensorFlow GraphDef as well as associated metadata
necessary for running computation in a graph. In order to serve a model, you
might wonder what kind of SignatureDefs are in each model, and what their
inputs and outputs are. The `show` command lets you examine the content of the
SavedModel in hierarchical order.
```
usage: saved_model_cli show [-h] --dir DIR [--all]
[--tag_set TAG_SET] [--signature_def SIGNATURE_DEF_KEY]
```
#### Examples
To show all available MetaGraphDef tag-sets in the SavedModel:
```
$saved_model_cli show --dir /tmp/saved_model_dir
The given SavedModel contains the following tag-sets:
serve
serve, gpu
```
To show all available SignatureDef keys in a MetaGraphDef:
```
$saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve
The given SavedModel MetaGraphDef contains SignatureDefs with the following keys:
SignatureDef key: "classify_x2_to_y3"
SignatureDef key: "classify_x_to_y"
SignatureDef key: "regress_x2_to_y3"
SignatureDef key: "regress_x_to_y"
SignatureDef key: "regress_x_to_y2"
SignatureDef key: "serving_default"
```
For a MetaGraphDef with multiple tags in the tag-set, all tags must be passed
in, separated by ',':
```
$saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve,gpu
```
To show the TensorInfo for all inputs and outputs of a specific SignatureDef,
pass the SignatureDef key to the `--signature_def` option. This is very useful
when you want to know the tensor key, dtype, and shape of the input tensors for
executing the computation graph later.
```
$saved_model_cli show --dir \
/tmp/saved_model_dir --tag_set serve --signature_def serving_default
The given SavedModel SignatureDef contains the following input(s):
inputs['x'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: x:0
The given SavedModel SignatureDef contains the following output(s):
outputs['y'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: y:0
Method name is: tensorflow/serving/predict
```
To show all available information in the SavedModel, use the `--all` option:
```
$saved_model_cli show --dir /tmp/saved_model_dir --all
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:
signature_def['classify_x2_to_y3']:
The given SavedModel SignatureDef contains the following input(s):
inputs['inputs'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: x2:0
The given SavedModel SignatureDef contains the following output(s):
outputs['scores'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: y3:0
Method name is: tensorflow/serving/classify
...
signature_def['serving_default']:
The given SavedModel SignatureDef contains the following input(s):
inputs['x'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: x:0
The given SavedModel SignatureDef contains the following output(s):
outputs['y'] tensor_info:
dtype: DT_FLOAT
shape: (-1, 1)
name: y:0
Method name is: tensorflow/serving/predict
```
### `run` command
SavedModel CLI also allows you to run a graph computation by passing in inputs
and displaying or saving the outputs.
```
usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def
SIGNATURE_DEF_KEY [--inputs INPUTS]
[--input_exprs INPUT_EXPRS] [--outdir OUTDIR]
[--overwrite] [--tf_debug]
```
Tensor keys are used to specify which input the values are being passed for.
There are two ways to pass inputs to the model. With the `--inputs` option, you
can pass in numpy ndarrays from files. With the `--input_exprs` option, you can
use Python expressions as inputs.
#### Input By File
To pass inputs in by file, use the `--inputs` option in the format
`<input_key>=<filename>` or `<input_key>=<filename>[<variable_name>]`. Inputs
are separated by semicolons. The file specified by `filename` is loaded using
`numpy.load`, so inputs can only be loaded from `.npy`, `.npz`, or pickle files.
The `variable_name` key is optional depending on the input file type, as
described in more detail below.
When loading from a `.npy` file, which always contains a numpy ndarray, the
content will be directly assigned to the specified input tensor. If a
`variable_name` is specified, it will be ignored and a warning will be issued.
When loading from a `.npz` zip file, you can use `variable_name` to specify
which variable within the zip file to load for the input tensor key. If nothing
is specified, SavedModel CLI will check that only one file is included in the
zip file and load it for the specified input tensor key.
When loading from a pickle file, if no `variable_name` is specified in the
square brackets, whatever is inside the pickle file will be passed to the
specified input tensor key. Otherwise, SavedModel CLI will assume a dictionary
is stored in the pickle file and use the value corresponding to
`variable_name`.
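For illustration, a minimal sketch (hypothetical paths) of creating each kind of
input file with numpy and pickle:
```python
import pickle
import numpy as np

x = np.array([[1], [2], [3]], dtype=np.float32)

np.save('/tmp/my_data1.npy', x)             # --inputs x1=/tmp/my_data1.npy
np.savez('/tmp/my_data1.npz', x=x)          # --inputs x1=/tmp/my_data1.npz[x]
with open('/tmp/my_data2.pkl', 'wb') as f:  # --inputs x2=/tmp/my_data2.pkl[x]
  pickle.dump({'x': x}, f)
```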
#### Input By Python Expression
To pass inputs in by Python expression, use the `--input_exprs` option. The
`numpy` module is available as `np`. For example, `input_key=np.ones((32, 32, 3))`
or `input_key=[[1], [2], [3]]`. This can be useful when you don't have data
files lying around but still want to sanity check the model with some simple
inputs that match the dtype and shape of the model signature.
#### Save Output
By default, SavedModel CLI prints outputs to the console. If a directory is
passed to the `--outdir` option, the outputs are saved as `.npy` files named
after the output tensor keys under the given directory. Use `--overwrite` to
overwrite existing output files.
#### TensorFlow Debugger (tfdbg) Integration
If the `--tf_debug` option is set, SavedModel CLI will use the TensorFlow
Debugger (tfdbg) to watch the intermediate tensors and runtime GraphDefs while
running the SavedModel.
#### Examples
Suppose we have a simple model that adds `x1` and `x2` to get output `y`, where
all tensors are of shape `(-1, 1)`, and we have two `.npy` files: file
`/tmp/my_data1.npy` contains the numpy ndarray `[[1], [2], [3]]`, and file
`/tmp/my_data2.npy` contains another numpy ndarray `[[0.5], [0.5], [0.5]]`. Now
let's run these two `.npy` files through the model to get `y`:
```
$saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
--signature_def x1_x2_to_y --inputs 'x1=/tmp/my_data1.npy;x2=/tmp/my_data2.npy' \
--outdir /tmp/out
Result for output key y:
[[ 1.5]
[ 2.5]
[ 3.5]]
```
Similarly, we can run input tensors from an `.npz` file and a pickle file, as
well as overwrite the previous output file:
```
$saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
--signature_def x1_x2_to_y \
--inputs 'x1=/tmp/my_data1.npz[x];x2=/tmp/my_data2.pkl' --outdir /tmp/out \
--overwrite
Result for output key y:
[[ 1.5]
[ 2.5]
[ 3.5]]
```
You can also use a Python expression instead of an input file. Here we replace
input `x2` with a Python expression:
```
$saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
--signature_def x1_x2_to_y --inputs 'x1=/tmp/my_data1.npz[x]' \
--input_exprs 'x2=np.ones((3,1))'
Result for output key y:
[[ 2]
[ 3]
[ 4]]
```
To run the model with the TensorFlow Debugger turned on:
```
$saved_model_cli run --dir /tmp/saved_model_dir --tag_set serve \
--signature_def serving_default --inputs x=/tmp/data.npz[x] --tf_debug
```
View File
@ -62,7 +62,7 @@ with tf.device("/job:ps/task:7"):
v = tf.Variable(...)
```
**N.B.** Operations that mutate a variable, such as
**NOTE** Operations that mutate a variable, such as
@{tf.Variable.assign} and the parameter
update operations in a
@{tf.train.Optimizer} *must* run on
View File
@ -3522,6 +3522,256 @@ func Stage(scope *Scope, values []tf.Output, optional ...StageAttr) (o *tf.Opera
return scope.AddOperation(opspec)
}
// Table initializer that takes two tensors for keys and values respectively.
//
// Arguments:
// table_handle: Handle to a table which will be initialized.
// keys: Keys of type Tkey.
// values: Values of type Tval.
//
// Returns the created operation.
func InitializeTableV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "InitializeTableV2",
Input: []tf.Input{
table_handle, keys, values,
},
}
return scope.AddOperation(opspec)
}
// MutableHashTableV2Attr is an optional argument to MutableHashTableV2.
type MutableHashTableV2Attr func(optionalAttr)
// MutableHashTableV2Container sets the optional container attribute to value.
//
// value: If non-empty, this table is placed in the given container.
// Otherwise, a default container is used.
// If not specified, defaults to ""
func MutableHashTableV2Container(value string) MutableHashTableV2Attr {
return func(m optionalAttr) {
m["container"] = value
}
}
// MutableHashTableV2SharedName sets the optional shared_name attribute to value.
//
// value: If non-empty, this table is shared under the given name across
// multiple sessions.
// If not specified, defaults to ""
func MutableHashTableV2SharedName(value string) MutableHashTableV2Attr {
return func(m optionalAttr) {
m["shared_name"] = value
}
}
// MutableHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
//
// value: If true and shared_name is empty, the table is shared
// using the node name.
// If not specified, defaults to false
func MutableHashTableV2UseNodeNameSharing(value bool) MutableHashTableV2Attr {
return func(m optionalAttr) {
m["use_node_name_sharing"] = value
}
}
// Creates an empty hash table.
//
// This op creates a mutable hash table, specifying the type of its keys and
// values. Each value must be a scalar. Data can be inserted into the table using
// the insert operations. It does not support the initialization operation.
//
// Arguments:
// key_dtype: Type of the table keys.
// value_dtype: Type of the table values.
//
// Returns Handle to a table.
func MutableHashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableV2Attr) (table_handle tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "MutableHashTableV2",
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// HashTableV2Attr is an optional argument to HashTableV2.
type HashTableV2Attr func(optionalAttr)
// HashTableV2Container sets the optional container attribute to value.
//
// value: If non-empty, this table is placed in the given container.
// Otherwise, a default container is used.
// If not specified, defaults to ""
func HashTableV2Container(value string) HashTableV2Attr {
return func(m optionalAttr) {
m["container"] = value
}
}
// HashTableV2SharedName sets the optional shared_name attribute to value.
//
// value: If non-empty, this table is shared under the given name across
// multiple sessions.
// If not specified, defaults to ""
func HashTableV2SharedName(value string) HashTableV2Attr {
return func(m optionalAttr) {
m["shared_name"] = value
}
}
// HashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
//
// value: If true and shared_name is empty, the table is shared
// using the node name.
// If not specified, defaults to false
func HashTableV2UseNodeNameSharing(value bool) HashTableV2Attr {
return func(m optionalAttr) {
m["use_node_name_sharing"] = value
}
}
// Creates a non-initialized hash table.
//
// This op creates a hash table, specifying the type of its keys and values.
// Before using the table you will have to initialize it. After initialization the
// table will be immutable.
//
// Arguments:
// key_dtype: Type of the table keys.
// value_dtype: Type of the table values.
//
// Returns Handle to a table.
func HashTableV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...HashTableV2Attr) (table_handle tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "HashTableV2",
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// Replaces the contents of the table with the specified keys and values.
//
// The tensor `keys` must be of the same type as the keys of the table.
// The tensor `values` must be of the type of the table values.
//
// Arguments:
// table_handle: Handle to the table.
// keys: Any shape. Keys to look up.
// values: Values to associate with keys.
//
// Returns the created operation.
func LookupTableImportV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "LookupTableImportV2",
Input: []tf.Input{
table_handle, keys, values,
},
}
return scope.AddOperation(opspec)
}
// Outputs all keys and values in the table.
//
// Arguments:
// table_handle: Handle to the table.
//
//
//
// Returns Vector of all keys present in the table.Tensor of all values in the table. Indexed in parallel with `keys`.
func LookupTableExportV2(scope *Scope, table_handle tf.Output, Tkeys tf.DataType, Tvalues tf.DataType) (keys tf.Output, values tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"Tkeys": Tkeys, "Tvalues": Tvalues}
opspec := tf.OpSpec{
Type: "LookupTableExportV2",
Input: []tf.Input{
table_handle,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
// Updates the table to associate keys with values.
//
// The tensor `keys` must be of the same type as the keys of the table.
// The tensor `values` must be of the type of the table values.
//
// Arguments:
// table_handle: Handle to the table.
// keys: Any shape. Keys to look up.
// values: Values to associate with keys.
//
// Returns the created operation.
func LookupTableInsertV2(scope *Scope, table_handle tf.Output, keys tf.Output, values tf.Output) (o *tf.Operation) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "LookupTableInsertV2",
Input: []tf.Input{
table_handle, keys, values,
},
}
return scope.AddOperation(opspec)
}
// Looks up keys in a table, outputs the corresponding values.
//
// The tensor `keys` must be of the same type as the keys of the table.
// The output `values` is of the type of the table values.
//
// The scalar `default_value` is the value output for keys not present in the
// table. It must also be of the same type as the table values.
//
// Arguments:
// table_handle: Handle to the table.
// keys: Any shape. Keys to look up.
//
//
// Returns Same shape as `keys`. Values found in the table, or `default_values`
// for missing keys.
func LookupTableFindV2(scope *Scope, table_handle tf.Output, keys tf.Output, default_value tf.Output) (values tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "LookupTableFindV2",
Input: []tf.Input{
table_handle, keys, default_value,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// FakeQuantWithMinMaxArgsAttr is an optional argument to FakeQuantWithMinMaxArgs.
type FakeQuantWithMinMaxArgsAttr func(optionalAttr)
@ -7941,17 +8191,17 @@ func SdcaOptimizerAdaptative(value bool) SdcaOptimizerAttr {
// uniformly, and the optimizer is learning rate free and enjoys linear convergence
// rate.
//
// Proximal Stochastic Dual Coordinate Ascent, Shalev-Shwartz, Shai; Zhang, Tong.
// 2012 arXiv1211.2717S: http://arxiv.org/pdf/1211.2717v1.pdf
// [Proximal Stochastic Dual Coordinate Ascent](http://arxiv.org/pdf/1211.2717v1.pdf).<br>
// Shai Shalev-Shwartz, Tong Zhang. 2012
//
// Loss objective = \sum f_{i}(wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|
// $$Loss Objective = \sum f_{i} (wx_{i}) + (l2 / 2) * |w|^2 + l1 * |w|$$
//
// Adding vs. Averaging in Distributed Primal-Dual Optimization.
// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan, Peter Richtarik,
// Martin Takac http://arxiv.org/abs/1502.03508
// [Adding vs. Averaging in Distributed Primal-Dual Optimization](http://arxiv.org/abs/1502.03508).<br>
// Chenxin Ma, Virginia Smith, Martin Jaggi, Michael I. Jordan,
// Peter Richtarik, Martin Takac. 2015
//
// Stochastic Dual Coordinate Ascent with Adaptive Probabilities
// Dominik Csiba, Zheng Qu, Peter Richtarik https://arxiv.org/abs/1502.08053
// [Stochastic Dual Coordinate Ascent with Adaptive Probabilities](https://arxiv.org/abs/1502.08053).<br>
// Dominik Csiba, Zheng Qu, Peter Richtarik. 2015
//
// Arguments:
// sparse_example_indices: a list of vectors which contain example indices.
@ -10019,6 +10269,70 @@ func QuantizedRelu(scope *Scope, features tf.Output, min_features tf.Output, max
return op.Output(0), op.Output(1), op.Output(2)
}
// InitializeTableFromTextFileV2Attr is an optional argument to InitializeTableFromTextFileV2.
type InitializeTableFromTextFileV2Attr func(optionalAttr)
// InitializeTableFromTextFileV2VocabSize sets the optional vocab_size attribute to value.
//
// value: Number of elements of the file, use -1 if unknown.
// If not specified, defaults to -1
//
// REQUIRES: value >= -1
func InitializeTableFromTextFileV2VocabSize(value int64) InitializeTableFromTextFileV2Attr {
return func(m optionalAttr) {
m["vocab_size"] = value
}
}
// InitializeTableFromTextFileV2Delimiter sets the optional delimiter attribute to value.
//
// value: Delimiter to separate fields in a line.
// If not specified, defaults to "\t"
func InitializeTableFromTextFileV2Delimiter(value string) InitializeTableFromTextFileV2Attr {
return func(m optionalAttr) {
m["delimiter"] = value
}
}
// Initializes a table from a text file.
//
// It inserts one key-value pair into the table for each line of the file.
// The key and value are extracted from the whole line content, elements from the
// split line based on `delimiter` or the line number (starting from zero).
// Where to extract the key and value from a line is specified by `key_index` and
// `value_index`.
//
// - A value of -1 means use the line number(starting from zero), expects `int64`.
// - A value of -2 means use the whole line content, expects `string`.
// - A value >= 0 means use the index (starting at zero) of the split line based
// on `delimiter`.
//
// Arguments:
// table_handle: Handle to a table which will be initialized.
// filename: Filename of a vocabulary text file.
// key_index: Column index in a line to get the table `key` values from.
// value_index: Column index that represents information of a line to get the table
// `value` values from.
//
// Returns the created operation.
func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filename tf.Output, key_index int64, value_index int64, optional ...InitializeTableFromTextFileV2Attr) (o *tf.Operation) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"key_index": key_index, "value_index": value_index}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "InitializeTableFromTextFileV2",
Input: []tf.Input{
table_handle, filename,
},
Attrs: attrs,
}
return scope.AddOperation(opspec)
}
// ResourceSparseApplyProximalGradientDescentAttr is an optional argument to ResourceSparseApplyProximalGradientDescent.
type ResourceSparseApplyProximalGradientDescentAttr func(optionalAttr)
@ -10875,6 +11189,75 @@ func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
// MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2.
type MutableHashTableOfTensorsV2Attr func(optionalAttr)
// MutableHashTableOfTensorsV2Container sets the optional container attribute to value.
//
// value: If non-empty, this table is placed in the given container.
// Otherwise, a default container is used.
// If not specified, defaults to ""
func MutableHashTableOfTensorsV2Container(value string) MutableHashTableOfTensorsV2Attr {
return func(m optionalAttr) {
m["container"] = value
}
}
// MutableHashTableOfTensorsV2SharedName sets the optional shared_name attribute to value.
//
// value: If non-empty, this table is shared under the given name across
// multiple sessions.
// If not specified, defaults to ""
func MutableHashTableOfTensorsV2SharedName(value string) MutableHashTableOfTensorsV2Attr {
return func(m optionalAttr) {
m["shared_name"] = value
}
}
// MutableHashTableOfTensorsV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
// If not specified, defaults to false
func MutableHashTableOfTensorsV2UseNodeNameSharing(value bool) MutableHashTableOfTensorsV2Attr {
return func(m optionalAttr) {
m["use_node_name_sharing"] = value
}
}
// MutableHashTableOfTensorsV2ValueShape sets the optional value_shape attribute to value.
// If not specified, defaults to <>
func MutableHashTableOfTensorsV2ValueShape(value tf.Shape) MutableHashTableOfTensorsV2Attr {
return func(m optionalAttr) {
m["value_shape"] = value
}
}
// Creates an empty hash table.
//
// This op creates a mutable hash table, specifying the type of its keys and
// values. Each value must be a vector. Data can be inserted into the table using
// the insert operations. It does not support the initialization operation.
//
// Arguments:
// key_dtype: Type of the table keys.
// value_dtype: Type of the table values.
//
// Returns Handle to a table.
func MutableHashTableOfTensorsV2(scope *Scope, key_dtype tf.DataType, value_dtype tf.DataType, optional ...MutableHashTableOfTensorsV2Attr) (table_handle tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"key_dtype": key_dtype, "value_dtype": value_dtype}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "MutableHashTableOfTensorsV2",
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// ResourceApplyProximalAdagradAttr is an optional argument to ResourceApplyProximalAdagrad.
type ResourceApplyProximalAdagradAttr func(optionalAttr)
@ -13270,6 +13653,104 @@ func ReaderReadV2(scope *Scope, reader_handle tf.Output, queue_handle tf.Output)
return op.Output(0), op.Output(1)
}
// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2.
type MutableDenseHashTableV2Attr func(optionalAttr)
// MutableDenseHashTableV2Container sets the optional container attribute to value.
//
// value: If non-empty, this table is placed in the given container.
// Otherwise, a default container is used.
// If not specified, defaults to ""
func MutableDenseHashTableV2Container(value string) MutableDenseHashTableV2Attr {
return func(m optionalAttr) {
m["container"] = value
}
}
// MutableDenseHashTableV2SharedName sets the optional shared_name attribute to value.
//
// value: If non-empty, this table is shared under the given name across
// multiple sessions.
// If not specified, defaults to ""
func MutableDenseHashTableV2SharedName(value string) MutableDenseHashTableV2Attr {
return func(m optionalAttr) {
m["shared_name"] = value
}
}
// MutableDenseHashTableV2UseNodeNameSharing sets the optional use_node_name_sharing attribute to value.
// If not specified, defaults to false
func MutableDenseHashTableV2UseNodeNameSharing(value bool) MutableDenseHashTableV2Attr {
return func(m optionalAttr) {
m["use_node_name_sharing"] = value
}
}
// MutableDenseHashTableV2ValueShape sets the optional value_shape attribute to value.
//
// value: The shape of each value.
// If not specified, defaults to <>
func MutableDenseHashTableV2ValueShape(value tf.Shape) MutableDenseHashTableV2Attr {
return func(m optionalAttr) {
m["value_shape"] = value
}
}
// MutableDenseHashTableV2InitialNumBuckets sets the optional initial_num_buckets attribute to value.
//
// value: The initial number of hash table buckets. Must be a power
// of 2.
// If not specified, defaults to 131072
func MutableDenseHashTableV2InitialNumBuckets(value int64) MutableDenseHashTableV2Attr {
return func(m optionalAttr) {
m["initial_num_buckets"] = value
}
}
// MutableDenseHashTableV2MaxLoadFactor sets the optional max_load_factor attribute to value.
//
// value: The maximum ratio between number of entries and number of
// buckets before growing the table. Must be between 0 and 1.
// If not specified, defaults to 0.8
func MutableDenseHashTableV2MaxLoadFactor(value float32) MutableDenseHashTableV2Attr {
return func(m optionalAttr) {
m["max_load_factor"] = value
}
}
// Creates an empty hash table that uses tensors as the backing store. It uses
//
// "open addressing" with quadratic reprobing to resolve collisions.
//
// This op creates a mutable hash table, specifying the type of its keys and
// values. Each value must be a scalar. Data can be inserted into the table using
// the insert operations. It does not support the initialization operation.
//
// Arguments:
// empty_key: The key used to represent empty key buckets internally. Must not
// be used in insert or lookup operations.
// value_dtype: Type of the table values.
//
// Returns Handle to a table.
func MutableDenseHashTableV2(scope *Scope, empty_key tf.Output, value_dtype tf.DataType, optional ...MutableDenseHashTableV2Attr) (table_handle tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"value_dtype": value_dtype}
for _, a := range optional {
a(attrs)
}
opspec := tf.OpSpec{
Type: "MutableDenseHashTableV2",
Input: []tf.Input{
empty_key,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// LRNAttr is an optional argument to LRN.
type LRNAttr func(optionalAttr)
@ -13835,9 +14316,9 @@ func SoftsignGrad(scope *Scope, gradients tf.Output, features tf.Output) (backpr
//
// The polygamma function is defined as:
//
// ```
// \psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)
// ```
//
// \\(\psi^{(n)}(x) = \frac{d^n}{dx^n} \psi(x)\\)
//
// where \\(\psi(x)\\) is the digamma function.
func Polygamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
if scope.Err() != nil {
@ -15059,9 +15540,8 @@ func Any(scope *Scope, input tf.Output, reduction_indices tf.Output, optional ..
//
// The Hurwitz zeta function is defined as:
//
// ```
// \zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}
// ```
//
// \\(\zeta(x, q) = \sum_{n=0}^{\infty} (q + n)^{-x}\\)
func Zeta(scope *Scope, x tf.Output, q tf.Output) (z tf.Output) {
if scope.Err() != nil {
return
@ -18239,13 +18719,12 @@ func Pow(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
//
// The upper regularized incomplete Gamma function is defined as:
//
// ```
// Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)
// ```
// \\(Q(a, x) = Gamma(a, x) / Gamma(a) = 1 - P(a, x)\\)
//
// where
// ```
// Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt
// ```
//
// \\(Gamma(a, x) = int_{x}^{\infty} t^{a-1} exp(-t) dt\\)
//
// is the upper incomplete Gamma function.
//
// Note, above `P(a, x)` (`Igamma`) is the lower regularized complete
@ -18268,13 +18747,13 @@ func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
//
// The lower regularized incomplete Gamma function is defined as:
//
// ```
// P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)
// ```
//
// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\)
//
// where
// ```
// gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt
// ```
//
// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
//
// is the lower incomplete Gamma function.
//
// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete
@ -18297,14 +18776,14 @@ func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
//
// The regularized incomplete beta integral is defined as:
//
// ```
// I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}
// ```
//
// \\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\)
//
// where
//
// ```
// B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt
// ```
//
// \\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\)
//
//
// is the incomplete beta function and \\(B(a, b)\\) is the *complete*
// beta function.
@ -18744,6 +19223,44 @@ func SparseReshape(scope *Scope, input_indices tf.Output, input_shape tf.Output,
return op.Output(0), op.Output(1)
}
// Bucketizes 'input' based on 'boundaries'.
//
// For example, if the inputs are
// boundaries = [0, 10, 100]
// input = [[-5, 10000]
// [150, 10]
// [5, 100]]
//
// then the output will be
// output = [[0, 3]
// [3, 2]
// [1, 3]]
//
// Arguments:
// input: Any shape of Tensor containing int or float values.
// boundaries: A sorted list of floats giving the boundaries of the buckets.
//
// Returns Same shape as 'input', with each value of input replaced by its bucket index.
//
// @compatibility(numpy)
// Equivalent to np.digitize.
// @end_compatibility
func Bucketize(scope *Scope, input tf.Output, boundaries []float32) (output tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"boundaries": boundaries}
opspec := tf.OpSpec{
Type: "Bucketize",
Input: []tf.Input{
input,
},
Attrs: attrs,
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// Computes the product along segments of a tensor.
//
// Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
@ -20516,6 +21033,26 @@ func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.O
return op.Output(0), op.Output(1)
}
// Computes the number of elements in the given table.
//
// Arguments:
// table_handle: Handle to the table.
//
// Returns Scalar that contains number of elements in the table.
func LookupTableSizeV2(scope *Scope, table_handle tf.Output) (size tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
Type: "LookupTableSizeV2",
Input: []tf.Input{
table_handle,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
// ResizeBilinearGradAttr is an optional argument to ResizeBilinearGrad.
type ResizeBilinearGradAttr func(optionalAttr)
View File
@ -328,7 +328,7 @@ class OpDefLibrary(object):
# Need to flatten all the arguments into a list.
# pylint: disable=protected-access
g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
# pyline: enable=protected-access
# pylint: enable=protected-access
except AssertionError as e:
raise RuntimeError(
"Cannot determine graph for Op '%s' due to: %s"
View File
@ -3981,9 +3981,13 @@ class GraphKeys(object):
for more details.
* `REGULARIZATION_LOSSES`: regularization losses collected during graph
construction.
* `WEIGHTS`: weights inside neural network layers
* `BIASES`: biases inside neural network layers
* `ACTIVATIONS`: activations of neural network layers
The following standard keys are _defined_, but their collections are **not**
automatically populated as many of the others are:
* `WEIGHTS`
* `BIASES`
* `ACTIVATIONS`
"""
# Key to collect Variable objects that are global (shared across machines).
View File
@ -629,10 +629,10 @@ static string GetPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
}
void GenerateLowerCaseOpName(const string& str, string* result) {
char joiner = '_';
int last_index = str.size() - 1;
const char joiner = '_';
const int last_index = str.size() - 1;
for (int i = 0; i <= last_index; ++i) {
char c = str[i];
const char c = str[i];
// Emit a joiner only if a previous-lower-to-now-upper or a
// now-upper-to-next-lower transition happens.
if (isupper(c) && (i > 0)) {
@ -731,8 +731,8 @@ void PrintPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
printf("%s", GetPythonOps(ops, hidden_ops, require_shapes).c_str());
}
string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len) {
string op_list_str(op_wrapper_buf, op_wrapper_len);
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len) {
string op_list_str(op_list_buf, op_list_len);
OpList ops;
ops.ParseFromString(op_list_str);
return GetPythonOps(ops, {}, false);
View File
@ -33,9 +33,10 @@ string GetPythonOps(const OpList& ops, const std::vector<string>& hidden_ops,
bool require_shapes);
// Get the python wrappers for a list of ops in a OpList.
// buf should be a pointer to a buffer containing the binary encoded OpList
// proto, and len should be the length of that buffer.
string GetPythonWrappers(const char* op_wrapper_buf, size_t op_wrapper_len);
// `op_list_buf` should be a pointer to a buffer containing
// the binary encoded OpList proto, and `op_list_len` should be the
// length of that buffer.
string GetPythonWrappers(const char* op_list_buf, size_t op_list_len);
} // namespace tensorflow

View File

@ -25,7 +25,7 @@ limitations under the License.
// going from python bytes to const char* tries to decode the
// contents from utf-8 to unicode for Python version >= 3, but
// we want the bytes to be uninterpreted.
%typemap(in) (const char* op_wrapper_buf, size_t op_wrapper_len) {
%typemap(in) (const char* op_list_buf, size_t op_list_len) {
char* c_string;
Py_ssize_t py_size;
if (PyBytes_AsStringAndSize($input, &c_string, &py_size) == -1) {

View File

@ -2569,6 +2569,18 @@ tf_py_test(
],
)
tf_py_test(
name = "bucketize_op_test",
size = "small",
srcs = ["bucketize_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -12,35 +12,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for bucketization_op."""
"""Tests for bucketize_op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class BucketizationOpTest(test.TestCase):
def test_normal_usecase(self):
op = bucketization_op.bucketize(
def testInt(self):
op = math_ops._bucketize(
constant_op.constant([-5, 0, 2, 3, 5, 8, 10, 11, 12]),
boundaries=[0, 3, 8, 11])
expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4]
with self.test_session() as sess:
self.assertAllEqual(expected_out, sess.run(op))
def test_invalid_boundaries_order(self):
op = bucketization_op.bucketize(
def testFloat(self):
op = math_ops._bucketize(
constant_op.constant([-5., 0., 2., 3., 5., 8., 10., 11., 12.]),
boundaries=[0., 3., 8., 11.])
expected_out = [0, 1, 1, 2, 2, 3, 3, 4, 4]
with self.test_session() as sess:
self.assertAllEqual(expected_out, sess.run(op))
def test2DInput(self):
op = math_ops._bucketize(
constant_op.constant([[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]),
boundaries=[0, 3, 8, 11])
expected_out = [[0, 1, 1, 2, 2], [3, 3, 4, 4, 1]]
with self.test_session() as sess:
self.assertAllEqual(expected_out, sess.run(op))
def testInvalidBoundariesOrder(self):
op = math_ops._bucketize(
constant_op.constant([-5, 0]), boundaries=[0, 8, 3, 11])
with self.test_session() as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError, "Expected sorted boundaries"):
sess.run(op)
def testBoundariesNotList(self):
with self.assertRaisesRegexp(
TypeError, "Expected list for attr boundaries"):
math_ops._bucketize(constant_op.constant([-5, 0]), boundaries=0)
if __name__ == "__main__":
test.main()
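As a rough cross-check of the expected bucket ids used above (an illustration, not how the kernel is implemented), `numpy.digitize` with the same boundaries reproduces them:

```python
import numpy as np

values = np.array([-5, 0, 2, 3, 5, 8, 10, 11, 12])
boundaries = [0, 3, 8, 11]

# For each value, digitize returns the number of boundaries that are <= value,
# which matches the expected_out lists in the tests above.
print(np.digitize(values, boundaries).tolist())  # [0, 1, 1, 2, 2, 3, 3, 4, 4]
```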

View File

@ -29,6 +29,52 @@ cuda_py_test(
],
)
cuda_py_test(
name = "kullback_leibler_test",
size = "small",
srcs = ["kullback_leibler_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "normal_test",
size = "medium",
srcs = ["normal_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)
cuda_py_test(
name = "special_math_test",
size = "medium",
srcs = ["special_math_test.py"],
additional_deps = [
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -18,9 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import normal
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.platform import test
# pylint: disable=protected-access
@ -43,7 +43,7 @@ class KLTest(test.TestCase):
return name
a = MyDist(loc=0.0, scale=1.0)
self.assertEqual("OK", kullback_leibler.kl(a, a, name="OK"))
self.assertEqual("OK", kullback_leibler.kl_divergence(a, a, name="OK"))
def testDomainErrorExceptions(self):
@ -60,11 +60,11 @@ class KLTest(test.TestCase):
with self.test_session():
a = MyDistException(loc=0.0, scale=1.0)
kl = kullback_leibler.kl(a, a, allow_nan_stats=False)
kl = kullback_leibler.kl_divergence(a, a, allow_nan_stats=False)
with self.assertRaisesOpError(
"KL calculation between .* and .* returned NaN values"):
kl.eval()
kl_ok = kullback_leibler.kl(a, a)
kl_ok = kullback_leibler.kl_divergence(a, a)
self.assertAllEqual([float("nan")], kl_ok.eval())
def testRegistrationFailures(self):
@ -116,16 +116,16 @@ class KLTest(test.TestCase):
sub2 = Sub2(loc=0.0, scale=1.0)
sub11 = Sub11(loc=0.0, scale=1.0)
self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub1))
self.assertEqual("sub1-2", kullback_leibler.kl(sub1, sub2))
self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub1))
self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub11))
self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1))
self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2))
self.assertEqual("sub1-1", kullback_leibler.kl(sub11, sub1))
self.assertEqual("sub1-2", kullback_leibler.kl(sub11, sub2))
self.assertEqual("sub2-1", kullback_leibler.kl(sub2, sub11))
self.assertEqual("sub1-1", kullback_leibler.kl(sub1, sub11))
self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub1, sub1))
self.assertEqual("sub1-2", kullback_leibler.kl_divergence(sub1, sub2))
self.assertEqual("sub2-1", kullback_leibler.kl_divergence(sub2, sub1))
self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub11, sub11))
self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub11, sub1))
self.assertEqual("sub1-2", kullback_leibler.kl_divergence(sub11, sub2))
self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub11, sub1))
self.assertEqual("sub1-2", kullback_leibler.kl_divergence(sub11, sub2))
self.assertEqual("sub2-1", kullback_leibler.kl_divergence(sub2, sub11))
self.assertEqual("sub1-1", kullback_leibler.kl_divergence(sub1, sub11))
if __name__ == "__main__":

View File

@ -18,13 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import math
import numpy as np
from scipy import stats
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import normal as normal_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -33,7 +31,21 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
def try_import(name): # pylint: disable=invalid-name
module = None
try:
module = importlib.import_module(name)
except ImportError as e:
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
return module
stats = try_import("scipy.stats")
class NormalTest(test.TestCase):
@ -90,10 +102,8 @@ class NormalTest(test.TestCase):
sigma = constant_op.constant([math.sqrt(10.0)] * batch_size)
x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32)
normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
log_pdf = normal.log_prob(x)
self.assertAllClose(expected_log_pdf, log_pdf.eval())
self.assertAllEqual(normal.batch_shape_tensor().eval(),
log_pdf.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(),
@ -102,12 +112,17 @@ class NormalTest(test.TestCase):
self.assertAllEqual(normal.batch_shape, log_pdf.eval().shape)
pdf = normal.prob(x)
self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.eval().shape)
self.assertAllEqual(normal.batch_shape, pdf.get_shape())
self.assertAllEqual(normal.batch_shape, pdf.eval().shape)
if not stats:
return
expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
self.assertAllClose(expected_log_pdf, log_pdf.eval())
self.assertAllClose(np.exp(expected_log_pdf), pdf.eval())
def testNormalLogPDFMultidimensional(self):
with self.test_session():
batch_size = 6
@ -116,12 +131,10 @@ class NormalTest(test.TestCase):
batch_size)
x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
log_pdf = normal.log_prob(x)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.get_shape(), (6, 2))
self.assertAllClose(expected_log_pdf, log_pdf_values)
self.assertAllEqual(normal.batch_shape_tensor().eval(),
log_pdf.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(),
@ -132,12 +145,17 @@ class NormalTest(test.TestCase):
pdf = normal.prob(x)
pdf_values = pdf.eval()
self.assertEqual(pdf.get_shape(), (6, 2))
self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(), pdf_values.shape)
self.assertAllEqual(normal.batch_shape, pdf.get_shape())
self.assertAllEqual(normal.batch_shape, pdf_values.shape)
if not stats:
return
expected_log_pdf = stats.norm(mu.eval(), sigma.eval()).logpdf(x)
self.assertAllClose(expected_log_pdf, log_pdf_values)
self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
def testNormalCDF(self):
with self.test_session():
batch_size = 50
@ -146,14 +164,15 @@ class NormalTest(test.TestCase):
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_cdf = stats.norm(mu, sigma).cdf(x)
cdf = normal.cdf(x)
self.assertAllClose(expected_cdf, cdf.eval(), atol=0)
self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape)
self.assertAllEqual(normal.batch_shape, cdf.get_shape())
self.assertAllEqual(normal.batch_shape, cdf.eval().shape)
if not stats:
return
expected_cdf = stats.norm(mu, sigma).cdf(x)
self.assertAllClose(expected_cdf, cdf.eval(), atol=0)
def testNormalSurvivalFunction(self):
with self.test_session():
@ -163,14 +182,16 @@ class NormalTest(test.TestCase):
x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64)
normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_sf = stats.norm(mu, sigma).sf(x)
sf = normal.survival_function(x)
self.assertAllClose(expected_sf, sf.eval(), atol=0)
self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape)
self.assertAllEqual(normal.batch_shape, sf.get_shape())
self.assertAllEqual(normal.batch_shape, sf.eval().shape)
if not stats:
return
expected_sf = stats.norm(mu, sigma).sf(x)
self.assertAllClose(expected_sf, sf.eval(), atol=0)
def testNormalLogCDF(self):
with self.test_session():
@ -180,15 +201,18 @@ class NormalTest(test.TestCase):
x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64)
normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_cdf = stats.norm(mu, sigma).logcdf(x)
cdf = normal.log_cdf(x)
self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5)
self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(), cdf.eval().shape)
self.assertAllEqual(normal.batch_shape, cdf.get_shape())
self.assertAllEqual(normal.batch_shape, cdf.eval().shape)
if not stats:
return
expected_cdf = stats.norm(mu, sigma).logcdf(x)
self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5)
def testFiniteGradientAtDifficultPoints(self):
for dtype in [np.float32, np.float64]:
g = ops.Graph()
@ -217,15 +241,18 @@ class NormalTest(test.TestCase):
x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64)
normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_sf = stats.norm(mu, sigma).logsf(x)
sf = normal.log_survival_function(x)
self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5)
self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(), sf.eval().shape)
self.assertAllEqual(normal.batch_shape, sf.get_shape())
self.assertAllEqual(normal.batch_shape, sf.eval().shape)
if not stats:
return
expected_sf = stats.norm(mu, sigma).logsf(x)
self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5)
def testNormalEntropyWithScalarInputs(self):
# Scipy.stats.norm cannot deal with the shapes in the other test.
with self.test_session():
@ -233,16 +260,18 @@ class NormalTest(test.TestCase):
sigma_v = 4.56
normal = normal_lib.Normal(loc=mu_v, scale=sigma_v)
# scipy.stats.norm cannot deal with these shapes.
expected_entropy = stats.norm(mu_v, sigma_v).entropy()
entropy = normal.entropy()
self.assertAllClose(expected_entropy, entropy.eval())
self.assertAllEqual(normal.batch_shape_tensor().eval(),
entropy.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(),
entropy.eval().shape)
self.assertAllEqual(normal.batch_shape, entropy.get_shape())
self.assertAllEqual(normal.batch_shape, entropy.eval().shape)
# scipy.stats.norm cannot deal with these shapes.
if not stats:
return
expected_entropy = stats.norm(mu_v, sigma_v).entropy()
self.assertAllClose(expected_entropy, entropy.eval())
def testNormalEntropy(self):
with self.test_session():
@ -288,15 +317,18 @@ class NormalTest(test.TestCase):
p = np.hstack((p, np.exp(-33), 1. - np.exp(-33)))
normal = normal_lib.Normal(loc=mu, scale=sigma)
expected_x = stats.norm(mu, sigma).ppf(p)
x = normal.quantile(p)
self.assertAllClose(expected_x, x.eval(), atol=0.)
self.assertAllEqual(normal.batch_shape_tensor().eval(), x.get_shape())
self.assertAllEqual(normal.batch_shape_tensor().eval(), x.eval().shape)
self.assertAllEqual(normal.batch_shape, x.get_shape())
self.assertAllEqual(normal.batch_shape, x.eval().shape)
if not stats:
return
expected_x = stats.norm(mu, sigma).ppf(p)
self.assertAllClose(expected_x, x.eval(), atol=0.)
def _baseQuantileFiniteGradientAtDifficultPoints(self, dtype):
g = ops.Graph()
with g.as_default():
@ -450,7 +482,7 @@ class NormalTest(test.TestCase):
n_a = normal_lib.Normal(loc=mu_a, scale=sigma_a)
n_b = normal_lib.Normal(loc=mu_b, scale=sigma_b)
kl = kullback_leibler.kl(n_a, n_b)
kl = kullback_leibler.kl_divergence(n_a, n_b)
kl_val = sess.run(kl)
kl_expected = ((mu_a - mu_b)**2 / (2 * sigma_b**2) + 0.5 * (

View File

@ -19,18 +19,30 @@ from __future__ import division
from __future__ import print_function
import collections
import importlib
import numpy as np
from scipy import special
from scipy import stats
from tensorflow.contrib.distributions.python.ops import special_math
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import special_math
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
def try_import(name): # pylint: disable=invalid-name
module = None
try:
module = importlib.import_module(name)
except ImportError as e:
tf_logging.warning("Could not import %s: %s" % (name, str(e)))
return module
special = try_import("scipy.special")
stats = try_import("scipy.stats")
sm = special_math
@ -66,6 +78,9 @@ class NdtriTest(test.TestCase):
def testNdtri(self):
"""Verifies that ndtri computation is correct."""
with self.test_session():
if not special:
return
p = np.linspace(0., 1.0, 50).astype(np.float64)
# Quantile performs piecewise rational approximation so adding some
# special input values to make sure we hit all the pieces.
@ -113,6 +128,9 @@ class NdtrTest(test.TestCase):
self._test_grid_no_log(dtype, grid_spec, error_spec)
def _test_grid_log(self, dtype, grid_spec, error_spec):
if not special:
return
with self.test_session():
grid = _make_grid(dtype, grid_spec)
actual = sm.log_ndtr(grid).eval()
@ -137,6 +155,9 @@ class NdtrTest(test.TestCase):
atol=error_spec.atol)
def _test_grid_no_log(self, dtype, grid_spec, error_spec):
if not special:
return
with self.test_session():
grid = _make_grid(dtype, grid_spec)
actual = sm.ndtr(grid).eval()
@ -267,6 +288,9 @@ class NdtrGradientTest(test.TestCase):
self.assert_all_true(np.isfinite(grad_eval))
# Versus scipy.
if not (special and stats):
return
expected = stats.norm.pdf(raw_grid)
if self._use_log:
expected /= special.ndtr(raw_grid)
@ -323,6 +347,9 @@ class LogCDFLaplaceTest(test.TestCase):
_check_strictly_increasing(actual)
# Versus scipy.
if not stats:
return
scipy_dist = stats.laplace(loc=0., scale=1.)
expected = scipy_dist.logcdf(grid.astype(scipy_dtype))
self.assertAllClose(

View File

@ -1128,7 +1128,7 @@ class TensorArrayTest(test.TestCase):
dtype=dtypes.float32, size=0, dynamic_size=False, infer_shape=True)
self.assertEqual(0, ta.size().eval())
# Don't actually perform the pack. This stores the static shape.
ta.unstack(array_ops.zeros([0, 3, 5]))
ta.unstack(array_ops.zeros([0, 3, 5])).mark_used()
packed = ta.stack()
self.assertAllEqual([0, 3, 5], packed.eval().shape)
# Concatenating zero tensors along their first dimension gives a

View File

@ -71,7 +71,9 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.gen_control_flow_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
# We override the 'tuple' for a control flow op, so we keep python's
@ -84,6 +86,7 @@ _basetuple = tuple
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
"""Asserts that the given condition is true.
@ -1677,14 +1680,20 @@ def _UnpackIfSingleton(res):
return res
def cond(pred, fn1, fn2, strict=False, name=None):
"""Return either `fn1()` or `fn2()` based on the boolean predicate `pred`.
# pylint: disable=g-doc-args
@deprecation.deprecated_args(
None,
"fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
"fn1", "fn2")
def cond(pred, true_fn=None, false_fn=None, strict=False, name=None,
fn1=None, fn2=None):
"""Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
`fn1` and `fn2` both return lists of output tensors. `fn1` and `fn2` must have
the same non-zero number and type of outputs.
`true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
`false_fn` must have the same non-zero number and type of outputs.
Note that the conditional execution applies only to the operations defined in
`fn1` and `fn2`. Consider the following simple program:
`true_fn` and `false_fn`. Consider the following simple program:
```python
z = tf.multiply(a, b)
@ -1698,28 +1707,35 @@ def cond(pred, fn1, fn2, strict=False, name=None):
Although this behavior is consistent with the dataflow model of TensorFlow,
it has occasionally surprised some users who expected a lazier semantics.
Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
call to `cond`, and not at all during `Session.run()`). `cond`
stitches together the graph fragments created during the `true_fn` and
`false_fn` calls with some additional graph nodes to ensure that the right
branch gets executed depending on the value of `pred`.
`tf.cond` supports nested structures as implemented in
`tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
(possibly nested) value structure of lists, tuples, and/or named tuples.
`tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
same (possibly nested) value structure of lists, tuples, and/or named tuples.
Singleton lists and tuples form the only exceptions to this: when returned by
`fn1` and/or `fn2`, they are implicitly unpacked to single values. This
behavior is disabled by passing `strict=True`.
`true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
This behavior is disabled by passing `strict=True`.
Args:
pred: A scalar determining whether to return the result of `fn1` or `fn2`.
fn1: The callable to be performed if pred is true.
fn2: The callable to be performed if pred is false.
pred: A scalar determining whether to return the result of `true_fn` or
`false_fn`.
true_fn: The callable to be performed if pred is true.
false_fn: The callable to be performed if pred is false.
strict: A boolean that enables/disables 'strict' mode; see above.
name: Optional name prefix for the returned tensors.
Returns:
Tensors returned by the call to either `fn1` or `fn2`. If the callables
return a singleton list, the element is extracted from the list.
Tensors returned by the call to either `true_fn` or `false_fn`. If the
callables return a singleton list, the element is extracted from the list.
Raises:
TypeError: if `fn1` or `fn2` is not callable.
ValueError: if `fn1` and `fn2` do not return the same number of tensors, or
return tensors of different types.
TypeError: if `true_fn` or `false_fn` is not callable.
ValueError: if `true_fn` and `false_fn` do not return the same number of
tensors, or return tensors of different types.
Example:
@ -1734,12 +1750,30 @@ def cond(pred, fn1, fn2, strict=False, name=None):
```
"""
with ops.name_scope(name, "cond", [pred]) as name:
if not callable(fn1):
raise TypeError("fn1 must be callable.")
if not callable(fn2):
raise TypeError("fn2 must be callable.")
# We needed to make true_fn/false_fn keyword arguments for
# backwards-compatibility. This check exists so that we can convert back to
# having them be positional arguments.
# TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
# `fn1` and `fn2` are deleted.
if fn1 is not None:
if true_fn is not None:
raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
true_fn = fn1
elif true_fn is None:
raise TypeError("cond(): true_fn argument required")
if fn2 is not None:
if false_fn is not None:
raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
false_fn = fn2
elif false_fn is None:
raise TypeError("cond(): false_fn argument required")
if not callable(true_fn):
raise TypeError("true_fn must be callable.")
if not callable(false_fn):
raise TypeError("false_fn must be callable.")
with ops.name_scope(name, "cond", [pred]) as name:
# Add the Switch to the graph.
if isinstance(pred, bool):
raise TypeError("pred must not be a Python bool")
@ -1754,18 +1788,18 @@ def cond(pred, fn1, fn2, strict=False, name=None):
# Build the graph for the true branch in a new context.
context_t = CondContext(pred, pivot_1, branch=1)
context_t.Enter()
orig_res_t, res_t = context_t.BuildCondBranch(fn1)
orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
if orig_res_t is None:
raise ValueError("fn1 must have a return value.")
raise ValueError("true_fn must have a return value.")
context_t.ExitResult(res_t)
context_t.Exit()
# Build the graph for the false branch in a new context.
context_f = CondContext(pred, pivot_2, branch=0)
context_f.Enter()
orig_res_f, res_f = context_f.BuildCondBranch(fn2)
orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
if orig_res_f is None:
raise ValueError("fn2 must have a return value.")
raise ValueError("false_fn must have a return value.")
context_f.ExitResult(res_f)
context_f.Exit()
@ -1778,14 +1812,14 @@ def cond(pred, fn1, fn2, strict=False, name=None):
nest.assert_same_structure(orig_res_t, orig_res_f)
except TypeError as e:
raise TypeError(
"Incompatible return types of fn1 and fn2: {}".format(e))
"Incompatible return types of true_fn and false_fn: {}".format(e))
except ValueError as e:
raise ValueError(
"Incompatible return values of fn1 and fn2: {}".format(e))
"Incompatible return values of true_fn and false_fn: {}".format(e))
# Add the final merge to the graph.
if not res_t:
raise ValueError("fn1 and fn2 must return at least one result.")
raise ValueError("true_fn and false_fn must return at least one result.")
res_t_flat = nest.flatten(res_t)
res_f_flat = nest.flatten(res_f)
@ -1799,8 +1833,9 @@ def cond(pred, fn1, fn2, strict=False, name=None):
val_x = x if isinstance(x, ops.Tensor) else x.values
val_y = y if isinstance(y, ops.Tensor) else y.values
if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
raise ValueError("Outputs of fn1 and fn2 must have the same type: "
"%s, %s" % (val_x.dtype.name, val_y.dtype.name))
raise ValueError(
"Outputs of true_fn and false_fn must have the same type: %s, %s" %
(val_x.dtype.name, val_y.dtype.name))
merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
@ -1815,6 +1850,7 @@ def cond(pred, fn1, fn2, strict=False, name=None):
if not strict:
merges = _UnpackIfSingleton(merges)
return merges
# pylint: enable=g-doc-args
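To make the signature change concrete, a minimal sketch (mirroring the unit tests added in this change, using this file's own import paths) of the new keyword form and the deprecated `fn1`/`fn2` form that the shim above still accepts:

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops

x = constant_op.constant(2)
y = constant_op.constant(5)

# Preferred spelling after this change.
z_new = control_flow_ops.cond(
    math_ops.less(x, y),
    true_fn=lambda: math_ops.multiply(x, 17),
    false_fn=lambda: math_ops.add(y, 23))

# Deprecated spelling; still accepted, but routed through the fn1/fn2 shim
# and flagged by the deprecated_args decorator.
z_old = control_flow_ops.cond(
    math_ops.less(x, y),
    fn1=lambda: math_ops.multiply(x, 17),
    fn2=lambda: math_ops.add(y, 23))
```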
def _resource_safe_shape(t):
@ -2546,12 +2582,16 @@ def while_loop(cond, body, loop_vars, shape_invariants=None,
`cond` and `body`. `cond` and `body` both take as many arguments as there are
`loop_vars`.
While `cond` evaluates to true, `body` is executed.
In addition to regular Tensors or IndexedSlices, the body may accept and
return TensorArray objects. The flows of the TensorArray objects will
be appropriately forwarded between loops and during gradient calculations.
Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
call to `while_loop`, and not at all during `Session.run()`). `while_loop`
stitches together the graph fragments created during the `cond` and `body`
calls with some additional graph nodes to make something that repeats
`body` until `cond` returns false.
For correctness, `tf.while_loop()` strictly enforces shape invariants for
the loop variables. A shape invariant is a (possibly partial) shape that
is unchanged across the iterations of the loop. An error will be raised
@ -2880,10 +2920,10 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
operation returns the tensors generated by `default`.
`tf.case` supports nested structures as implemented in
`tensorflow.python.util.nest`. Both `fn1` and `fn2` must return the same
`tensorflow.python.util.nest`. All of the callables must return the same
(possibly nested) value structure of lists, tuples, and/or named tuples.
Singleton lists and tuples form the only exceptions to this: when returned by
`fn1` and/or `fn2`, they are implicitly unpacked to single values. This
a callable, they are implicitly unpacked to single values. This
behavior is disabled by passing `strict=True`.
Example 1:
@ -2911,9 +2951,6 @@ def case(pred_fn_pairs, default, exclusive=False, strict=False, name="case"):
Expressions:
```
x = tf.constant(0)
y = tf.constant(1)
z = tf.constant(2)
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)

View File

@ -324,6 +324,69 @@ class SwitchTestCase(TensorFlowTestCase):
self.assertEquals(grad_x_false.eval(), 0.)
class CondTest(TensorFlowTestCase):
def testCondTrue(self):
with self.test_session():
x = constant_op.constant(2)
y = constant_op.constant(5)
z = control_flow_ops.cond(
math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
lambda: math_ops.add(y, 23))
self.assertEquals(z.eval(), 34)
def testCondFalse(self):
with self.test_session():
x = constant_op.constant(2)
y = constant_op.constant(1)
z = control_flow_ops.cond(
math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
lambda: math_ops.add(y, 23))
self.assertEquals(z.eval(), 24)
def testCondTrueLegacy(self):
with self.test_session():
x = constant_op.constant(2)
y = constant_op.constant(5)
z = control_flow_ops.cond(
math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
fn2=lambda: math_ops.add(y, 23))
self.assertEquals(z.eval(), 34)
def testCondFalseLegacy(self):
with self.test_session():
x = constant_op.constant(2)
y = constant_op.constant(1)
z = control_flow_ops.cond(
math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
fn2=lambda: math_ops.add(y, 23))
self.assertEquals(z.eval(), 24)
def testCondMissingArg1(self):
with self.test_session():
x = constant_op.constant(1)
with self.assertRaises(TypeError):
control_flow_ops.cond(True, false_fn=lambda: x)
def testCondMissingArg2(self):
with self.test_session():
x = constant_op.constant(1)
with self.assertRaises(TypeError):
control_flow_ops.cond(True, lambda: x)
def testCondDuplicateArg1(self):
with self.test_session():
x = constant_op.constant(1)
with self.assertRaises(TypeError):
control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
def testCondDuplicateArg2(self):
with self.test_session():
x = constant_op.constant(1)
with self.assertRaises(TypeError):
control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
class ContextTest(TensorFlowTestCase):
def testCondContext(self):

View File

@ -44,11 +44,13 @@ def _registered_kl(type_a, type_b):
return kl_fn
def kl(dist_a, dist_b, allow_nan_stats=True, name=None):
"""Get the KL-divergence KL(dist_a || dist_b).
def kl_divergence(distribution_a, distribution_b,
allow_nan_stats=True, name=None):
"""Get the KL-divergence KL(distribution_a || distribution_b).
If there is no KL method registered specifically for `type(dist_a)` and
`type(dist_b)`, then the class hierarchies of these types are searched.
If there is no KL method registered specifically for `type(distribution_a)`
and `type(distribution_b)`, then the class hierarchies of these types are
searched.
If one KL method is registered between any pairs of classes in these two
parent hierarchies, it is used.
@ -58,11 +60,11 @@ def kl(dist_a, dist_b, allow_nan_stats=True, name=None):
If more than one such shortest path exists, the first method
identified in the search is used (favoring a shorter MRO distance to
`type(dist_a)`).
`type(distribution_a)`).
Args:
dist_a: The first distribution.
dist_b: The second distribution.
distribution_a: The first distribution.
distribution_b: The second distribution.
allow_nan_stats: Python `bool`, default `True`. When `True`,
statistics (e.g., mean, mode, variance) use the value "`NaN`" to
indicate the result is undefined. When `False`, an exception is raised
@ -70,20 +72,22 @@ def kl(dist_a, dist_b, allow_nan_stats=True, name=None):
name: Python `str` name prefixed to Ops created by this class.
Returns:
A Tensor with the batchwise KL-divergence between dist_a and dist_b.
A Tensor with the batchwise KL-divergence between `distribution_a`
and `distribution_b`.
Raises:
NotImplementedError: If no KL method is defined for distribution types
of dist_a and dist_b.
of `distribution_a` and `distribution_b`.
"""
kl_fn = _registered_kl(type(dist_a), type(dist_b))
kl_fn = _registered_kl(type(distribution_a), type(distribution_b))
if kl_fn is None:
raise NotImplementedError(
"No KL(dist_a || dist_b) registered for dist_a type %s and dist_b "
"type %s" % (type(dist_a).__name__, type(dist_b).__name__))
"No KL(distribution_a || distribution_b) registered for distribution_a "
"type %s and distribution_b type %s"
% (type(distribution_a).__name__, type(distribution_b).__name__))
with ops.name_scope("KullbackLeibler"):
kl_t = kl_fn(dist_a, dist_b, name=name)
kl_t = kl_fn(distribution_a, distribution_b, name=name)
if allow_nan_stats:
return kl_t
@ -96,7 +100,7 @@ def kl(dist_a, dist_b, allow_nan_stats=True, name=None):
math_ops.reduce_any(math_ops.is_nan(kl_t))),
["KL calculation between %s and %s returned NaN values "
"(and was called with allow_nan_stats=False). Values:"
% (dist_a.name, dist_b.name), kl_t])]):
% (distribution_a.name, distribution_b.name), kl_t])]):
return array_ops.identity(kl_t, name="checked_kl")
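A short usage sketch of the renamed API (illustrative; mirrors the normal_test change above and uses this tree's own import paths):

```python
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal as normal_lib

n_a = normal_lib.Normal(loc=0.0, scale=1.0)
n_b = normal_lib.Normal(loc=1.0, scale=2.0)

# Batchwise KL(n_a || n_b); raises NotImplementedError if no KL method is
# registered for this pair of distribution types.
kl = kullback_leibler.kl_divergence(n_a, n_b)
```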

View File

@ -20,8 +20,6 @@ from __future__ import print_function
import math
from tensorflow.contrib.distributions.python.ops import kullback_leibler
from tensorflow.contrib.distributions.python.ops import special_math
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -32,6 +30,8 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import special_math
__all__ = [

View File

@ -63,16 +63,27 @@ GetSessionHandle
GetSessionHandleV2
GetSessionTensor
HashTable
HashTableV2
InitializeTable
InitializeTableV2
InitializeTableFromTextFile
InitializeTableFromTextFileV2
LookupTableExport
LookupTableExportV2
LookupTableFind
LookupTableFindV2
LookupTableImport
LookupTableImportV2
LookupTableInsert
LookupTableInsertV2
LookupTableSize
LookupTableSizeV2
MutableDenseHashTable
MutableDenseHashTableV2
MutableHashTable
MutableHashTableV2
MutableHashTableOfTensors
MutableHashTableOfTensorsV2
Mutex
MutexAcquire
MutexRelease
@ -220,6 +231,7 @@ BatchFFT3D
BatchIFFT
BatchIFFT2D
BatchIFFT3D
Bucketize
Complex
ComplexAbs
Conj

Some files were not shown because too many files have changed in this diff.