Merge pull request #19950 from akshaym/branch_200251004

Branch 200251004

commit b202db076e
@@ -142,8 +142,10 @@ void TestRemoteExecute(bool async) {
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
                                 status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts,
                                             TFE_DEVICE_PLACEMENT_EXPLICIT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

@@ -205,6 +207,83 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }

void TestRemoteExecuteSilentCopies(bool async) {
  tensorflow::ServerDef server_def = GetServerDef(2);

  // This server def has the task index set to 0.
  string serialized = server_def.SerializeAsString();

  server_def.set_task_index(1);

  std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
  ASSERT_TRUE(
      tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
          .ok());
  ASSERT_TRUE(worker_server->Start().ok());

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
                                 status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);

  TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
  TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
  const char remote_device_name[] =
      "/job:localhost/replica:0/task:1/device:CPU:0";

  // Handles are on task0, but op is on remote (task1).
  TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0);
  TFE_OpSetDevice(matmul, remote_device_name, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TFE_TensorHandle* retvals[1];
  int num_retvals = 1;
  TFE_Execute(matmul, &retvals[0], &num_retvals, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  auto* retval_task0 = TFE_TensorHandleCopyToDevice(
      retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteTensorHandle(retval_task0);
  float product[4] = {0};
  EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
  memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
  TF_DeleteTensor(t);
  EXPECT_EQ(7, product[0]);
  EXPECT_EQ(10, product[1]);
  EXPECT_EQ(15, product[2]);
  EXPECT_EQ(22, product[3]);

  TFE_DeleteTensorHandle(h0_task0);
  TFE_DeleteTensorHandle(h1_task0);
  TFE_DeleteTensorHandle(retvals[0]);

  TFE_DeleteOp(matmul);

  TFE_ContextAsyncWait(ctx, status);
  TFE_DeleteContext(ctx, status);
  EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_DeleteStatus(status);

  // TODO(nareshmodi): Figure out how to correctly shut the server down.
  worker_server.release();
}

TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
  TestRemoteExecuteSilentCopies(true);
}

TEST(CAPI, TensorHandle) {
  TFE_TensorHandle* h = TestMatrixTensorHandle();
  EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
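Note: the expected values 7, 10, 15, 22 checked in TestRemoteExecuteSilentCopies assume TestMatrixTensorHandle (a helper defined elsewhere in this test file) produces the row-major 2x2 matrix {1, 2, 3, 4}. A standalone sketch of that arithmetic, independent of the TFE C API:

#include <cstdio>

int main() {
  // Square the 2x2 matrix [[1, 2], [3, 4]].
  const float a[2][2] = {{1, 2}, {3, 4}};
  float product[2][2] = {};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k)
        product[i][j] += a[i][k] * a[k][j];
  // Prints 7 10 15 22, matching the EXPECT_EQ checks above.
  std::printf("%g %g %g %g\n", product[0][0], product[0][1], product[1][0],
              product[1][1]);
  return 0;
}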
@@ -42,7 +42,7 @@ tf_cc_binary(
        "//tensorflow/compiler/xla/service:cpu_plugin",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
-       "@grpc//:grpc++_unsecure",
+       "@grpc//:grpc++",
    ],
)

@@ -61,7 +61,7 @@ tf_cc_test(
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
-       "@grpc//:grpc++_unsecure",
+       "@grpc//:grpc++",
    ],
)

@@ -74,6 +74,6 @@ cc_library(
        "//tensorflow/compiler/xla/service",
        "//tensorflow/compiler/xla/service:platform_util",
        "//tensorflow/core/distributed_runtime/rpc:grpc_util",
-       "@grpc//:grpc++_unsecure",
+       "@grpc//:grpc++",
    ],
)
@@ -2379,7 +2379,6 @@ cc_library(
        ":hlo_graph_dumper",
        ":hlo_pass",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/core:lib",
    ],
)

@@ -2574,6 +2573,7 @@ cc_library(
    hdrs = ["hlo_graph_dumper.h"],
    deps = [
        ":hlo",
        ":hlo_casting_utils",
        ":hlo_execution_profile",
        ":hlo_tfgraph_builder",
        "//tensorflow/compiler/xla:literal_util",
@@ -47,12 +47,16 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
        element_instr = fused_expression_root;
      }
    }
    // Special handling of kReduce instructions -- the fusion
    // applies to the first operand.
    if (element_instr->opcode() == HloOpcode::kReduce) {
      return element_instr->operand(0)->shape();
    }
    return element_instr->shape();
  };

  // The elementwise output shapes must be the same (including layout)
-  return ShapeUtil::ShapeUtil::Equal(get_element_shape(instr1),
-                                     get_element_shape(instr2));
+  return ShapeUtil::Equal(get_element_shape(instr1), get_element_shape(instr2));
}

bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
@@ -36,6 +36,11 @@ const char kModulePrefix[] = R"(
    scalar_lhs = f32[] parameter(0)
    scalar_rhs = f32[] parameter(1)
    ROOT add = f32[] add(scalar_lhs, scalar_rhs)
  }
  scalar_mul_computation {
    scalar_lhs = f32[] parameter(0)
    scalar_rhs = f32[] parameter(1)
    ROOT mul = f32[] multiply(scalar_lhs, scalar_rhs)
  })";

TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {

@@ -67,6 +72,34 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
              op::Tuple(op::Reduce(), op::Reduce()));
}

TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
  auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
    fused_computation_1 {
      p1.1 = f32[6400]{0} parameter(1)
      mul = f32[6400]{0} multiply(p1.1, p1.1)
      const.1 = f32[] parameter(0)
      ROOT reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0}, to_apply=scalar_add_computation
    }

    fused_computation_2 {
      p1.2 = f32[6400]{0} parameter(1)
      r1 = f32[64,100]{0,1} reshape(p1.2)
      const.2 = f32[] parameter(0)
      ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
    }

    ENTRY entry {
      p0 = f32[] parameter(0)
      p1 = f32[6400]{0} parameter(1)
      const.2 = f32[] constant(1)
      fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
      fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
      ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
    })"))
                    .ValueOrDie();
  ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
}

TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
  // Two sibling fusions with reduce instruction roots sharing the same input
  // param.
@@ -357,7 +357,6 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
  std::list<HloInstruction*> post_order;
  std::list<HloInstruction*> trace_instructions;
  tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
  std::vector<HloInstruction*> dfs_stack;
  for (auto& instruction : instructions_) {
    if (instruction->opcode() == HloOpcode::kTrace) {
      // Trace instructions aren't handled by the DFS visitor. Add trace
@@ -28,6 +28,8 @@ limitations under the License.

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -723,17 +725,14 @@ string HloDotDumper::DumpRootTag() {
                to_id, node_body, node_shape, NodeColorAttributes(color));
}

-static const HloInstruction* TryGetFusionParameterConstant(
+static const HloConstantInstruction* TryGetFusionParameterConstant(
    const HloInstruction* instr) {
  if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
    return nullptr;
  }
  const HloInstruction* fusion = instr->parent()->FusionInstruction();
  const HloInstruction* operand = fusion->operand(instr->parameter_number());
-  if (operand->opcode() == HloOpcode::kConstant) {
-    return operand;
-  }
-  return nullptr;
+  return DynCast<HloConstantInstruction>(operand);
}

bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {

@@ -826,7 +825,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {

string HloDotDumper::GetInstructionNodeInlinedOperands(
    const HloInstruction* instr) {
-  auto stringify_constant = [](const HloInstruction* constant) {
+  auto stringify_constant = [](const HloConstantInstruction* constant) {
    const auto& shape = constant->shape();

    // If the shape has a dimension of size zero, print it as e.g.

@@ -845,7 +844,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
        *elem_count *= dim;
      }
    }
-    if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
+    if (elem_count.has_value() && *elem_count <= 8) {
      return Printf("%s (%s)", constant->literal().ToString(),
                    ShapeUtil::HumanString(constant->shape()));
    }

@@ -864,9 +863,10 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
  std::vector<string> lines;
  for (int64 i = 0; i < instr->operand_count(); ++i) {
    const HloInstruction* operand = instr->operand(i);
+   const auto* constant_operand = DynCast<HloConstantInstruction>(operand);
    optional<string> operand_str;
-   if (operand->opcode() == HloOpcode::kConstant) {
-     operand_str = stringify_constant(operand);
+   if (constant_operand != nullptr) {
+     operand_str = stringify_constant(constant_operand);
    } else if (ShouldMergeIntoUsers(operand)) {
      // Special case: If the operand is a parameter to a fusion node and it
      // always has a constant value, display it like a regular constant.

@@ -874,7 +874,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
      // For other parameters, use the parameter number rather than the proper
      // name, because that's generally how people think of the node.
      if (operand->opcode() == HloOpcode::kParameter) {
-       if (const HloInstruction* constant =
+       if (const HloConstantInstruction* constant =
                TryGetFusionParameterConstant(operand)) {
          operand_str = stringify_constant(constant);
        } else {
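Note: the dumper changes above lean on XLA's LLVM-style casting helpers. A minimal standalone sketch of the DynCast/Cast idiom they use (simplified stand-ins based on opcode tags, not the actual hlo_casting_utils implementation):

#include <cassert>

enum class HloOpcode { kConstant, kTrace, kAdd };

struct HloInstruction {
  explicit HloInstruction(HloOpcode opcode) : opcode_(opcode) {}
  virtual ~HloInstruction() = default;
  HloOpcode opcode() const { return opcode_; }

 private:
  HloOpcode opcode_;
};

struct HloConstantInstruction : HloInstruction {
  HloConstantInstruction() : HloInstruction(HloOpcode::kConstant) {}
  // Opcode-based RTTI: true iff the instruction is backed by this subclass.
  static bool ClassOf(const HloInstruction* i) {
    return i->opcode() == HloOpcode::kConstant;
  }
};

// Dynamic downcast: nullptr when the instruction is not of the target kind.
template <typename To, typename From>
const To* DynCast(const From* from) {
  return To::ClassOf(from) ? static_cast<const To*>(from) : nullptr;
}

// Checked downcast: asserts instead of returning nullptr.
template <typename To, typename From>
const To* Cast(const From* from) {
  assert(To::ClassOf(from) && "Cast on mismatched instruction kind");
  return static_cast<const To*>(from);
}

int main() {
  HloConstantInstruction constant;
  HloInstruction add(HloOpcode::kAdd);
  assert(DynCast<HloConstantInstruction>(&constant) != nullptr);
  assert(DynCast<HloConstantInstruction>(&add) == nullptr);  // wrong kind
  (void)Cast<HloConstantInstruction>(&constant);             // checked cast
  return 0;
}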
@@ -178,6 +178,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
                                        slice_limits, slice_strides);
      break;
    }
    case HloOpcode::kConstant: {
      CHECK(proto.has_literal());
      TF_ASSIGN_OR_RETURN(auto literal,
                          Literal::CreateFromProto(proto.literal()));
      instruction = CreateConstant(std::move(literal));
      break;
    }
    case HloOpcode::kTrace: {
      TF_RET_CHECK(proto.operand_ids_size() == 1)
          << "Trace instruction should have 1 operand but sees "
          << proto.operand_ids_size();
      CHECK(proto.has_literal());
      TF_ASSIGN_OR_RETURN(auto literal,
                          Literal::CreateFromProto(proto.literal()));
      instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
      break;
    }
    default: {
      instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
      for (const int64 operand_id : proto.operand_ids()) {

@@ -223,22 +240,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
    instruction->called_computations_.push_back(fused_computation);
  }

-  if (instruction->opcode() == HloOpcode::kTrace) {
-    TF_RET_CHECK(instruction->operands().size() == 1)
-        << "Trace instruction should have 1 operand but sees "
-        << instruction->operands().size();
-    instruction->mutable_operand(0)->set_tracing(instruction.get());
-  }

  TF_RET_CHECK(!proto.name().empty());
  instruction->SetAndSanitizeName(proto.name());

  instruction->metadata_ = proto.metadata();
  instruction->backend_config_ = proto.backend_config();
-  if (proto.has_literal()) {
-    TF_ASSIGN_OR_RETURN(instruction->literal_,
-                        Literal::CreateFromProto(proto.literal()));
-  }
  instruction->parameter_number_ = proto.parameter_number();

  instruction->tuple_index_ = proto.tuple_index();

@@ -301,20 +307,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
    const string& tag, HloInstruction* operand) {
-  auto instruction =
-      WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
-  instruction->operands_.push_back(operand);
-  instruction->literal_ = Literal::CreateR1U8(tag);
-  operand->set_tracing(instruction.get());
-  return instruction;
+  return MakeUnique<HloTraceInstruction>(tag, operand);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
    std::unique_ptr<Literal> literal) {
-  auto instruction =
-      WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape()));
-  instruction->literal_ = std::move(literal);
-  return instruction;
+  return MakeUnique<HloConstantInstruction>(std::move(literal));
}

/* static */ std::unique_ptr<HloInstruction>
@@ -1321,6 +1319,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
    case HloOpcode::kBroadcast:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
+   case HloOpcode::kConstant:
+   case HloOpcode::kTrace:
      clone = CloneWithNewOperandsImpl(shape, new_operands, context);
      break;
    // Unary ops.

@@ -1470,9 +1470,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
      clone =
          CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
      break;
-   case HloOpcode::kConstant:
-     clone = CreateConstant(literal_->CloneToUnique());
-     break;
    case HloOpcode::kFusion: {
      HloModule* module = context != nullptr ? context->module() : GetModule();
      HloComputation* new_fused_computation = nullptr;

@@ -1520,8 +1517,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
    case HloOpcode::kGenerateToken:
      clone = CreateGenerateToken(new_operands);
      break;
-   case HloOpcode::kTrace:
-     LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
  }
  SetupDerivedInstruction(clone.get());
  clone->set_parent(parent_);

@@ -1602,13 +1597,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
  return hlo;
}

-const Literal& HloInstruction::literal() const {
-  CHECK_EQ(HloOpcode::kConstant, opcode_);
-  return *literal_;
-}
-
-bool HloInstruction::HasLiteral() const { return literal_ != nullptr; }

int64 HloInstruction::tuple_index() const {
  CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
  return tuple_index_;

@@ -1702,10 +1690,6 @@ void HloInstruction::AddUser(HloInstruction* user) {
  }
}

-bool HloInstruction::IsConstant() const {
-  return opcode_ == HloOpcode::kConstant;
-}

bool HloInstruction::HasConstantOperand() const {
  for (const HloInstruction* operand : operands_) {
    if (operand->IsConstant()) {

@@ -1782,7 +1766,6 @@ bool HloInstruction::IdenticalSlowPath(
    // These opcodes have complex or special behavior so just return false.
    case HloOpcode::kDomain:
    case HloOpcode::kRng:
-   case HloOpcode::kTrace:
    case HloOpcode::kWhile:
    case HloOpcode::kGenerateToken:
      return false;

@@ -1790,10 +1773,6 @@ bool HloInstruction::IdenticalSlowPath(
    case HloOpcode::kParameter:
      return parameter_number() == other.parameter_number();

-   // A constant is defined by the value in the literal.
-   case HloOpcode::kConstant:
-     return literal() == other.literal();

    // A reduce-precision operation is determined by the bit sizes.
    case HloOpcode::kReducePrecision:
      return exponent_bits() == other.exponent_bits() &&

@@ -1878,6 +1857,8 @@ bool HloInstruction::IdenticalSlowPath(
    case HloOpcode::kBroadcast:
    case HloOpcode::kMap:
    case HloOpcode::kSlice:
+   case HloOpcode::kConstant:
+   case HloOpcode::kTrace:
      LOG(FATAL) << "Base class impl called for opcode with subclass: "
                 << opcode();
  }

@@ -2172,34 +2153,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  string operands;
-  if (opcode() == HloOpcode::kConstant) {
-    // For constants, show the actual value in place of an empty operand list.
-    //
-    // In HloInstruction, sometimes a constant literal is not constructed due
-    // to its size. Skip the printing in this case.
-    if (HasLiteral() && ((!ShapeUtil::IsTuple(shape()) &&
-                          ShapeUtil::ElementsIn(shape()) <= 10) ||
-                         options.print_large_constants())) {
-      // Literal::ToString emits multidimensional arrays over multiple
-      // lines. Compact this into one line by stripping out white space.
-      string tmp = literal().ToString();
-      std::replace(tmp.begin(), tmp.end(), '\n', ' ');
-      std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
-      bool first = true;
-      // Concatenate elements in "v" with spaces separating them, but ignoring
-      // empty entries.
-      for (const auto& s : v) {
-        if (s.empty()) {
-          continue;
-        }
-        StrAppend(&operands, (first ? "" : " "), s);
-        first = false;
-      }
-    } else {
-      // Do not show large constants or tuples.
-      operands = "{...}";
-    }
-  } else if (opcode() == HloOpcode::kParameter) {
+  if (opcode() == HloOpcode::kParameter) {
    StrAppend(&operands, parameter_number_);
  } else {
    tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);

@@ -2410,9 +2364,6 @@ HloInstructionProto HloInstruction::ToProto() const {

  *proto.mutable_metadata() = metadata_;
  proto.set_backend_config(backend_config_);
-  if (literal_ != nullptr) {
-    *proto.mutable_literal() = literal_->ToProto();
-  }
  proto.set_parameter_number(parameter_number_);
  if (opcode() == HloOpcode::kFusion) {
    proto.set_fusion_kind(xla::ToString(fusion_kind()));

@@ -2518,12 +2469,6 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
  trace_instruction_ = trace_instruction;
}

-string HloInstruction::TracingTag() const {
-  CHECK_EQ(HloOpcode::kTrace, opcode());
-  CHECK(literal_ != nullptr);
-  return literal_->GetR1U8AsString();
-}

bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }

bool HloInstruction::IsFusable() const {

@@ -3035,10 +2980,6 @@ bool HloInstruction::IsElementwiseBinary() const {

bool HloInstruction::IsElementwise() const {
  switch (opcode_) {
-   // Nullary elementwise operations.
-   case HloOpcode::kConstant:
-     return true;

    // Unary elementwise operations.
    case HloOpcode::kAbs:
    case HloOpcode::kRoundNearestAfz:

@@ -3500,23 +3441,6 @@ void HloInstruction::set_outer_dimension_partitions(
  outer_dimension_partitions_ = outer_dimension_partitions;
}

-void HloInstruction::RelayoutConstant(const Layout& new_layout,
-                                      const ShapeIndex& shape_index) {
-  CHECK_EQ(opcode(), HloOpcode::kConstant);
-  Shape* mutable_array_subshape =
-      ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
-  CHECK(ShapeUtil::IsArray(*mutable_array_subshape));
-
-  // Normally array_subshape will always have a layout, but this invariant is
-  // temporarily broken in LayoutAssignment::AssignLayouts.
-
-  if (!mutable_array_subshape->has_layout() ||
-      !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
-    literal_ = literal_->Relayout(new_layout, shape_index);
-    *mutable_array_subshape->mutable_layout() = new_layout;
-  }
-}

// TODO(b/80131774): Remove these temporary methods after transition.
int64 HloInstruction::feature_index() const {
  return Cast<HloBatchNormInstruction>(this)->feature_index();

@@ -3574,4 +3498,21 @@ const std::vector<int64>& HloInstruction::slice_strides() const {
bool HloInstruction::IsInPlaceSlice() const {
  return Cast<HloSliceInstruction>(this)->IsInPlaceSlice();
}

+const Literal& HloInstruction::literal() const {
+  return Cast<HloConstantInstruction>(this)->literal();
+}
+
+bool HloInstruction::IsConstant() const {
+  return DynCast<HloConstantInstruction>(this) != nullptr;
+}
+
+void HloInstruction::RelayoutConstant(const Layout& new_layout,
+                                      const ShapeIndex& shape_index) {
+  Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
+}
+
+string HloInstruction::TracingTag() const {
+  return Cast<HloTraceInstruction>(this)->TracingTag();
+}
}  // namespace xla
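Note: the delegating definitions added at the bottom of the file keep the old base-class API compiling while the state moves into subclasses. A minimal standalone sketch of that transition pattern (simplified types, not the actual XLA classes):

#include <cassert>
#include <memory>

struct Instruction {
  virtual ~Instruction() = default;
  int literal() const;       // forwards to the subclass that owns the value
  bool IsConstant() const;   // true iff backed by the Constant subclass
  virtual bool is_constant() const { return false; }
};

struct Constant : Instruction {
  explicit Constant(int value) : value_(value) {}
  bool is_constant() const override { return true; }
  int value() const { return value_; }

 private:
  int value_;
};

int Instruction::literal() const {
  // Equivalent of Cast<HloConstantInstruction>(this)->literal(): dies when
  // called on a non-constant, like the old CHECK_EQ(kConstant, opcode_).
  assert(is_constant());
  return static_cast<const Constant*>(this)->value();
}

bool Instruction::IsConstant() const {
  // Equivalent of DynCast<HloConstantInstruction>(this) != nullptr.
  return is_constant();
}

int main() {
  std::unique_ptr<Instruction> c = std::make_unique<Constant>(42);
  assert(c->IsConstant());
  assert(c->literal() == 42);
  return 0;
}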
@@ -875,14 +875,6 @@ class HloInstruction {
  template <typename HloInstructionPtr>
  Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);

-  // Returns the literal associated with this instruction.
-  //
-  // Note: only constant and parameter opcodes have an associated literal.
-  const Literal& literal() const;
-
-  // Returns whether there is literal associated with this instruction.
-  bool HasLiteral() const;

  // Returns the parameter number associated with this instruction.
  //
  // Note: only parameter opcodes have an associated parameter number.

@@ -1014,14 +1006,6 @@ class HloInstruction {
  string infeed_config() const { return infeed_config_; }
  void set_infeed_config(const string& config) { infeed_config_ = config; }

-  // Returns a tag to be used in tracing.
-  //
-  // Precondition: opcode() == HloOpcode::kTrace
-  string TracingTag() const;
-
-  // Returns whether the instruction is a constant.
-  bool IsConstant() const;

  // Returns true if this instruction is fused, ie contained within a fusion
  // instruction.
  bool IsFused() const;

@@ -1452,12 +1436,6 @@ class HloInstruction {
  void set_outer_dimension_partitions(
      const std::vector<int64>& outer_dimension_partitions);

-  // Change the layout for an Constant Hlo instruction to match new_layout. For
-  // tuple shaped constants shape_index is the path to the internal array
-  // subshape whose layout needs to be changed.
-  void RelayoutConstant(const Layout& new_layout,
-                        const ShapeIndex& shape_index = {});

  // Old methods kept for smooth subclassing transition BEGIN.
  // TODO(b/80131774): Remove this code.

@@ -1504,6 +1482,19 @@ class HloInstruction {

  // Delegates to HloSliceInstruction::IsInPlaceSlice.
  bool IsInPlaceSlice() const;

+  // Returns the literal associated with this instruction.
+  const Literal& literal() const;
+
+  // Returns whether the instruction is a constant.
+  bool IsConstant() const;
+
+  // Delegate to HloConstantInstruction::RelayoutConstant.
+  void RelayoutConstant(const Layout& new_layout,
+                        const ShapeIndex& shape_index = {});
+
+  // Delegates to HloTraceInstruction::TracingTag.
+  string TracingTag() const;
  // Old methods kept for smooth subclassing transition END.

 protected:

@@ -1544,7 +1535,7 @@ class HloInstruction {
      CanonicalNameMap* canonical_name_map) const;

  // Prints an operand to a string.
-  string OperandsToStringWithCanonicalNameMap(
+  virtual string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const;

@@ -1639,9 +1630,6 @@ class HloInstruction {
  // Result shape of this instruction.
  Shape shape_;

-  // Literal, only present for kConstant.
-  std::unique_ptr<Literal> literal_;

  // Constant index, only present for kGetTupleElement.
  int64 tuple_index_ = -1;
@@ -20,6 +20,7 @@ limitations under the License.
namespace xla {

using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;

HloBatchNormInstruction::HloBatchNormInstruction(

@@ -586,4 +587,105 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
  return MakeUnique<HloSliceInstruction>(shape, new_operands[0], slice_starts_,
                                         slice_limits_, slice_strides_);
}

HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
    : HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
      literal_(std::move(literal)) {}

HloInstructionProto HloConstantInstruction::ToProto() const {
  HloInstructionProto proto = HloInstruction::ToProto();
  *proto.mutable_literal() = literal_->ToProto();
  return proto;
}

bool HloConstantInstruction::IsElementwise() const { return true; }

void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
                                              const ShapeIndex& shape_index) {
  Shape* mutable_array_subshape =
      ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
  CHECK(ShapeUtil::IsArray(*mutable_array_subshape));

  // Normally array_subshape will always have a layout, but this invariant is
  // temporarily broken in LayoutAssignment::AssignLayouts.
  if (!mutable_array_subshape->has_layout() ||
      !LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
    literal_ = literal_->Relayout(new_layout, shape_index);
    *mutable_array_subshape->mutable_layout() = new_layout;
  }
}

bool HloConstantInstruction::IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations) const {
  const auto& other_constant =
      static_cast<const HloConstantInstruction&>(other);
  return literal() == other_constant.literal();
}

std::unique_ptr<HloInstruction>
HloConstantInstruction::CloneWithNewOperandsImpl(
    const Shape& shape,
    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
    HloCloneContext* context) const {
  return MakeUnique<HloConstantInstruction>(literal_->CloneToUnique());
}

string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
    const HloPrintOptions& options,
    CanonicalNameMap* canonical_name_map) const {
  string operands;
  // For constants, show the actual value in place of an empty operand list.
  if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
      options.print_large_constants()) {
    // Literal::ToString emits multidimensional arrays over multiple
    // lines. Compact this into one line by stripping out white space.
    string tmp = literal().ToString();
    std::replace(tmp.begin(), tmp.end(), '\n', ' ');
    std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
    bool first = true;
    // Concatenate elements in "v" with spaces separating them, but ignoring
    // empty entries.
    for (const auto& s : v) {
      if (s.empty()) {
        continue;
      }
      StrAppend(&operands, (first ? "" : " "), s);
      first = false;
    }
  } else {
    // Do not show large constants or tuples.
    operands = "{...}";
  }
  return operands;
}
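Note: the whitespace-stripping loop in OperandsToStringWithCanonicalNameMap compacts a multi-line literal dump onto one line. A standalone sketch of the same idea without the tensorflow str_util helpers:

#include <algorithm>
#include <iostream>
#include <sstream>
#include <string>

// Collapses a multi-line literal dump onto one line: newlines become spaces,
// then runs of whitespace are squeezed to single spaces.
std::string CompactOneLine(std::string s) {
  std::replace(s.begin(), s.end(), '\n', ' ');
  std::istringstream tokens(s);  // extraction splits on any whitespace
  std::string token, out;
  bool first = true;
  while (tokens >> token) {
    if (!first) out += ' ';
    out += token;
    first = false;
  }
  return out;
}

int main() {
  // Prints: { 1, 2, 3, 4 }
  std::cout << CompactOneLine("{ 1, 2,\n  3, 4 }") << "\n";
  return 0;
}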
HloTraceInstruction::HloTraceInstruction(const string& tag,
                                         HloInstruction* operand)
    : HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
      literal_(Literal::CreateR1U8(tag)) {
  AppendOperand(operand);
  operand->set_tracing(this);
}

HloInstructionProto HloTraceInstruction::ToProto() const {
  HloInstructionProto proto = HloInstruction::ToProto();
  *proto.mutable_literal() = literal_->ToProto();
  return proto;
}

bool HloTraceInstruction::IdenticalSlowPath(
    const HloInstruction& other,
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations) const {
  return false;
}

std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
    const Shape& shape,
    tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
    HloCloneContext* context) const {
  LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
}
}  // namespace xla
@@ -433,6 +433,62 @@ class HloSliceInstruction : public HloInstruction {
  // Describes whether the slice can be lowered to an offset into the operand.
  bool is_in_place_slice_ = false;
};

class HloConstantInstruction : public HloInstruction {
 public:
  explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
  // Returns the literal associated with this instruction.
  const Literal& literal() const { return *literal_; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;
  // Returns true if this instruction is elementwise on all its operands.
  bool IsElementwise() const override;

  // Change the layout for an Constant Hlo instruction to match new_layout. For
  // tuple shaped constants shape_index is the path to the internal array
  // subshape whose layout needs to be changed.
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});

 private:
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  string OperandsToStringWithCanonicalNameMap(
      const HloPrintOptions& options,
      CanonicalNameMap* canonical_name_map) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
      HloCloneContext* context) const override;
  // TODO(b/36360764): Remove unique_ptr wrapping.
  std::unique_ptr<Literal> literal_;
};

class HloTraceInstruction : public HloInstruction {
 public:
  explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
  // Returns a tag to be used in tracing.
  string TracingTag() const { return literal_->GetR1U8AsString(); }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

 private:
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape,
      tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
      HloCloneContext* context) const override;
  // TODO(b/36360764): Remove unique_ptr wrapping.
  std::unique_ptr<Literal> literal_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
@@ -127,9 +127,14 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const {
  for (HloInstruction* instruction : *companions) {
    // Go through all the communicating instructions (send, recv) of the given
    // companion, and record their device.
+   auto it = tracked_instructions_comms_.find(instruction);
+   if (it == tracked_instructions_comms_.end()) {
+     // Companions can be added even if they have no communicating
+     // instructions, if they are parent of companions.
+     continue;
+   }
    std::unordered_set<int64> comm_devices;
-   for (HloInstruction* comm_instruction :
-        tracked_instructions_comms_.at(instruction)) {
+   for (HloInstruction* comm_instruction : it->second) {
      auto device = GetInstructionDevice(*comm_instruction);
      TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString()
                           << " does not have a device";
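Note: the VerifyCompanionSets change is the classic find-once pattern: look the key up a single time, treat a missing entry as "skip" rather than an error, and reuse the iterator instead of a second at() lookup. A standalone sketch:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, std::vector<int>> comms = {
      {"send0", {0, 1}}, {"recv1", {1}}};

  for (const std::string& key : {std::string("send0"), std::string("parent")}) {
    // One lookup via find() instead of a count()/at() or find()/at() pair.
    auto it = comms.find(key);
    if (it == comms.end()) {
      // Entries without communicating instructions are skipped, not errors.
      continue;
    }
    for (int device : it->second) {
      std::cout << key << " -> device " << device << "\n";
    }
  }
  return 0;
}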
@@ -232,7 +232,13 @@ def _dnn_tree_combined_model_fn(features,
    return update_op

  if predict_with_tree_only:
-   tree_train_logits = tree_logits
+   if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.PREDICT:
+     tree_train_logits = tree_logits
+   else:
+     tree_train_logits = control_flow_ops.cond(
+         global_step > dnn_steps_to_train,
+         lambda: tree_logits,
+         lambda: dnn_logits)
  else:
    tree_train_logits = dnn_logits + tree_logits
@@ -36,6 +36,7 @@ except ImportError:

_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'

@@ -69,8 +70,8 @@ class TPUClusterResolver(ClusterResolver):
    return _GKE_ENV_VARIABLE in os.environ

  @staticmethod
- def _gkeMaster():
-   return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
+ def _gkeEndpoints():
+   return os.environ[_GKE_ENV_VARIABLE]

  @staticmethod
  def _envVarFallback():

@@ -143,7 +144,7 @@ class TPUClusterResolver(ClusterResolver):
    # When using GKE with Cloud TPUs, the env variable will be set.
    if tpu is None:
      if in_gke:
-       tpu = self._gkeMaster()
+       tpu = self._gkeEndpoints()
      else:
        tpu = self._envVarFallback()

@@ -214,7 +215,7 @@ class TPUClusterResolver(ClusterResolver):
      ValueError: If none of the TPUs specified exists.
    """
    if not self._shouldResolve():
-     return self._tpu
+     return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0]

    job_tasks = self.cluster_spec().job_tasks(self._job_name)
    if not job_tasks:

@@ -280,8 +281,12 @@ class TPUClusterResolver(ClusterResolver):
      # Case 3.
      return None
    # Case 2.
-   cluster_spec = {self._job_name: [self._tpu[len(
-       compat.as_bytes('grpc://')):]]}
+   cluster_spec = {
+       self._job_name: [
+           x[len(compat.as_bytes('grpc://')):]
+           for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))
+       ]
+   }

    if self._coordinator_address:
      # {1, 2}.a
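Note: the resolver now treats KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS as a comma-separated endpoint list and strips the grpc:// scheme when building the cluster spec. A standalone sketch of that parsing (C++ to keep these notes in one language; the change itself is Python):

#include <iostream>
#include <string>
#include <vector>

// Splits "grpc://a:8470,grpc://b:8470" into ["a:8470", "b:8470"],
// mirroring the endpoint-list handling added to TPUClusterResolver.
std::vector<std::string> ParseEndpoints(const std::string& value) {
  const std::string kScheme = "grpc://";
  std::vector<std::string> endpoints;
  size_t start = 0;
  while (start <= value.size()) {
    size_t comma = value.find(',', start);
    if (comma == std::string::npos) comma = value.size();
    std::string item = value.substr(start, comma - start);
    if (item.compare(0, kScheme.size(), kScheme) == 0) {
      item = item.substr(kScheme.size());  // strip the scheme prefix
    }
    if (!item.empty()) endpoints.push_back(item);
    start = comma + 1;
  }
  return endpoints;
}

int main() {
  for (const auto& e :
       ParseEndpoints("grpc://10.120.27.5:8470,grpc://10.120.27.6:8470")) {
    std::cout << e << "\n";  // 10.120.27.5:8470 then 10.120.27.6:8470
  }
  return 0;
}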
@@ -402,13 +402,61 @@ class TPUClusterResolverTest(test.TestCase):
        compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master())
    self.assertEqual(None, tpu_cluster_resolver.cluster_spec())

- def testGkeEnvironment(self):
+ def testGkeEnvironmentForDonut(self):
    os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
-   self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ)
+   self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
    self.assertTrue(TPUClusterResolver._inGke())
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470'),
-       compat.as_bytes(TPUClusterResolver._gkeMaster()))
+       compat.as_bytes(TPUClusterResolver._gkeEndpoints()))

    tpu_cluster_resolver = TPUClusterResolver()
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470'),
        compat.as_bytes(tpu_cluster_resolver.master()))
    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
    expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
    }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

    del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']

  def testGkeEnvironmentForPod(self):
    os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,'
                                                     'grpc://10.120.27.6:8470,'
                                                     'grpc://10.120.27.7:8470,'
                                                     'grpc://10.120.27.8:8470')

    self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
    self.assertTrue(TPUClusterResolver._inGke())
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470,'
                        'grpc://10.120.27.6:8470,'
                        'grpc://10.120.27.7:8470,'
                        'grpc://10.120.27.8:8470'),
        compat.as_bytes(TPUClusterResolver._gkeEndpoints()))

    tpu_cluster_resolver = TPUClusterResolver()
    self.assertEqual(
        compat.as_bytes('grpc://10.120.27.5:8470'),
        compat.as_bytes(tpu_cluster_resolver.master()))
    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
    expected_proto = """
    job {
      name: 'worker'
      tasks { key: 0 value: '10.120.27.5:8470' }
      tasks { key: 1 value: '10.120.27.6:8470' }
      tasks { key: 2 value: '10.120.27.7:8470' }
      tasks { key: 3 value: '10.120.27.8:8470' }
    }
    """
    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)

    del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']

  def testDiscoveryUrl(self):
@@ -18,7 +18,16 @@ cmake_policy(SET CMP0022 NEW)

# Options
option(tensorflow_VERBOSE "Enable for verbose output" OFF)

if(WIN32)
  # BoringSSL is disabled for windows as it currently doesn't build with
  # MSBuild. (Ninja is required.)
  option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF)
else()
  # BoringSSL is enabled for gRPC.
  option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" ON)
endif()

option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON)
option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF)
option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF)
tensorflow/contrib/cmake/external/grpc.cmake

@@ -20,6 +20,10 @@ set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc)
set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f)

if(WIN32)
  # We use unsecure gRPC because boringssl does not build on windows
+ set(grpc_TARGET grpc++_unsecure)
+ set(grpc_DEPENDS protobuf zlib)
+ set(grpc_SSL_PROVIDER NONE)
  if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
    set(grpc_STATIC_LIBRARIES
        ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib

@@ -32,9 +36,12 @@ if(WIN32)
        ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib)
  endif()
else()
+ set(grpc_TARGET grpc++)
+ set(grpc_DEPENDS boringssl protobuf zlib)
+ set(grpc_SSL_PROVIDER module)
  set(grpc_STATIC_LIBRARIES
-     ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a
-     ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a
+     ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a
+     ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a
+     ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a
      ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a
      ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a)

@@ -44,13 +51,13 @@ add_definitions(-DGRPC_ARES=0)

ExternalProject_Add(grpc
    PREFIX grpc
-   DEPENDS protobuf zlib
+   DEPENDS ${grpc_DEPENDS}
    GIT_REPOSITORY ${GRPC_URL}
    GIT_TAG ${GRPC_TAG}
    DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
    BUILD_IN_SOURCE 1
    BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES}
-   BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure
+   BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target ${grpc_TARGET}
    COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin
    INSTALL_COMMAND ""
    CMAKE_CACHE_ARGS

@@ -59,7 +66,7 @@ ExternalProject_Add(grpc
        -DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS}
        -DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES}
        -DZLIB_ROOT:STRING=${ZLIB_INSTALL}
-       -DgRPC_SSL_PROVIDER:STRING=NONE
+       -DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER}
)

# grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h.
@@ -77,6 +77,7 @@ py_library(
        "//tensorflow/python:device_util",
        "//tensorflow/python:distribute",
        "//tensorflow/python:framework_ops",
+       "//tensorflow/python:math_ops",
        "//tensorflow/python:pywrap_tensorflow",
        "//tensorflow/python:training",
        "//tensorflow/python:variable_scope",

@@ -590,3 +591,22 @@ cuda_py_test(
        "notsan",
    ],
)

cuda_py_test(
    name = "metrics_v1_test",
    srcs = ["metrics_v1_test.py"],
    additional_deps = [
        ":combinations",
        "@absl_py//absl/testing:parameterized",
        "//tensorflow/contrib/data/python/ops:batching",
        "//tensorflow/python:math_ops",
        "//tensorflow/python:metrics",
        "//tensorflow/python:variables",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/eager:test",
    ],
    tags = [
        "multi_and_single_gpu",
        "no_pip",
    ],
)
tensorflow/contrib/distribute/python/metrics_v1_test.py (new file)

@@ -0,0 +1,438 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for V1 metrics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import combinations
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import test
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import variables


def _labeled_dataset_fn():
  # First four batches of x: labels, predictions -> (labels == predictions)
  #  0: 0, 0 -> True;   1: 1, 1 -> True;   2: 2, 2 -> True;   3: 3, 0 -> False
  #  4: 4, 1 -> False;  5: 0, 2 -> False;  6: 1, 0 -> False;  7: 2, 1 -> False
  #  8: 3, 2 -> False;  9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
  # 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
  return dataset_ops.Dataset.range(1000).map(
      lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4)
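Note: the batch-by-batch expectations used below (e.g. accuracy [3/4, 3/8, 3/12, 4/16]) follow from the labels = x % 5, predictions = x % 3 construction above. A quick standalone recomputation (C++, matching the other sketches in these notes):

#include <cstdio>

int main() {
  // Recomputes the running-accuracy sequence from labels = x % 5 and
  // predictions = x % 3 with a batch size of 4.
  int correct = 0;
  for (int x = 0; x < 16; ++x) {
    if (x % 5 == x % 3) ++correct;
    if ((x + 1) % 4 == 0) {  // end of a batch of 4
      std::printf("after %2d examples: %d/%d\n", x + 1, correct, x + 1);
    }
  }
  // Prints 3/4, 3/8, 3/12, 4/16, matching testAccuracy below.
  return 0;
}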
def _boolean_dataset_fn():
  # First four batches of labels, predictions: {TP, FP, TN, FN}
  # with a threshold of 0.5:
  # T, T -> TP;  F, T -> FP;  T, F -> FN
  # F, F -> TN;  T, T -> TP;  F, T -> FP
  # T, F -> FN;  F, F -> TN;  T, T -> TP
  # F, T -> FP;  T, F -> FN;  F, F -> TN
  return dataset_ops.Dataset.from_tensor_slices({
      "labels": [True, False, True, False],
      "predictions": [True, True, False, False]}).repeat().batch(3)


def _threshold_dataset_fn():
  # First four batches of labels, predictions: {TP, FP, TN, FN}
  # with a threshold of 0.5:
  # True, 1.0 -> TP;  False, .75 -> FP;  True, .25 -> FN
  # False, 0.0 -> TN; True, 1.0 -> TP;   False, .75 -> FP
  # True, .25 -> FN;  False, 0.0 -> TN;  True, 1.0 -> TP
  # False, .75 -> FP; True, .25 -> FN;   False, 0.0 -> TN
  return dataset_ops.Dataset.from_tensor_slices({
      "labels": [True, False, True, False],
      "predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3)


def _regression_dataset_fn():
  return dataset_ops.Dataset.from_tensor_slices({
      "labels": [1., .5, 1., 0.],
      "predictions": [1., .75, .25, 0.]}).repeat()


def all_combinations():
  return combinations.combine(
      distribution=[combinations.default_strategy,
                    combinations.one_device_strategy,
                    combinations.mirrored_strategy_with_gpu_and_cpu,
                    combinations.mirrored_strategy_with_two_gpus],
      mode=["graph"])


# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
# metrics.precision_at_k
class MetricsV1Test(test.TestCase, parameterized.TestCase):

  def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
    with ops.Graph().as_default(), distribution.scope():
      iterator = distribution.distribute_dataset(
          dataset_fn).make_one_shot_iterator()
      value, update = distribution.call_for_each_tower(
          metric_fn, iterator.get_next())
      update = distribution.group(update)
      self.evaluate(variables.local_variables_initializer())
      # TODO(josh11b): Once we switch to using a global batch size for input,
      # replace "distribution.num_towers" with "1".
      batches_per_update = distribution.num_towers

      # Update variables using the first `num_towers` batches.
      self.evaluate(update)
      self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value),
                          0.001, msg="After first update")

      # Update variables using the second `num_towers` batches.
      self.evaluate(update)
      self.assertAllClose(expected_fn(2 * batches_per_update),
                          self.evaluate(value),
                          0.001,
                          msg="After second update")

      if batches_per_update == 1:  # Consume 4 input batches
        self.evaluate(update)
        self.assertAllClose(expected_fn(3 * batches_per_update),
                            self.evaluate(value),
                            0.001,
                            msg="After third update")
        self.evaluate(update)
        self.assertAllClose(expected_fn(4 * batches_per_update),
                            self.evaluate(value),
                            0.001,
                            msg="After fourth update")

  @combinations.generate(all_combinations())
  def testMean(self, distribution):
    def _dataset_fn():
      return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4)

    def _expected_fn(num_batches):
      # Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
      return num_batches * 2 - 0.5

    self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)
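Note: testMean's expectation encodes the closed form Mean(0..4n-1) = (4n-1)/2 = 2n - 0.5 for n batches of 4. A quick standalone check:

#include <cassert>
#include <cstdio>

int main() {
  // Mean of 0..(4n-1) over n batches of 4 equals (4n-1)/2 = 2n - 0.5.
  for (int n = 1; n <= 4; ++n) {
    double sum = 0;
    for (int x = 0; x < 4 * n; ++x) sum += x;
    double mean = sum / (4 * n);
    assert(mean == 2.0 * n - 0.5);
    std::printf("n=%d mean=%.1f\n", n, mean);
  }
  return 0;
}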
  @combinations.generate(all_combinations())
  def testAccuracy(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.accuracy(labels, predictions)

    def _expected_fn(num_batches):
      return [3./4, 3./8, 3./12, 4./16][num_batches - 1]

    self._test_metric(
        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testMeanPerClassAccuracy(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.mean_per_class_accuracy(
          labels, predictions, num_classes=5)

    def _expected_fn(num_batches):
      mean = lambda x: sum(x) / len(x)
      return [mean([1., 1., 1., 0., 0.]),
              mean([0.5, 0.5, 0.5, 0., 0.]),
              mean([1./3, 1./3, 0.5, 0., 0.]),
              mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1]

    self._test_metric(
        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testMeanIOU(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.mean_iou(
          labels, predictions, num_classes=5)

    def _expected_fn(num_batches):
      mean = lambda x: sum(x) / len(x)
      return [mean([1./2, 1./1, 1./1, 0.]),  # no class 4 in first batch
              mean([1./4, 1./4, 1./3, 0., 0.]),
              mean([1./6, 1./6, 1./5, 0., 0.]),
              mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1]

    self._test_metric(
        distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testMeanTensor(self, distribution):
    def _dataset_fn():
      dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
      # Want to produce a fixed, known shape, so drop remainder when batching.
      dataset = dataset.apply(batching.batch_and_drop_remainder(4))
      return dataset

    def _expected_fn(num_batches):
      # Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2
      # Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1
      # Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches
      # Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1
      first = 2. * num_batches - 2.
      return [first, first + 1., first + 2., first + 3.]

    self._test_metric(
        distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)

  @combinations.generate(all_combinations())
  def testAUCROC(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC",
                         summation_method="careful_interpolation")

    def _expected_fn(num_batches):
      return [0.5, 7./9, 0.8, 0.75][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testAUCPR(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.auc(labels, predictions, num_thresholds=8, curve="PR",
                         summation_method="careful_interpolation")

    def _expected_fn(num_batches):
      return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testFalseNegatives(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_negatives(labels, predictions)

    def _expected_fn(num_batches):
      return [1., 1., 2., 3.][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testFalseNegativesAtThresholds(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_negatives_at_thresholds(labels, predictions, [.5])

    def _expected_fn(num_batches):
      return [[1.], [1.], [2.], [3.]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testTrueNegatives(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.true_negatives(labels, predictions)

    def _expected_fn(num_batches):
      return [0., 1., 2., 3.][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testTrueNegativesAtThresholds(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.true_negatives_at_thresholds(labels, predictions, [.5])

    def _expected_fn(num_batches):
      return [[0.], [1.], [2.], [3.]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testFalsePositives(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_positives(labels, predictions)

    def _expected_fn(num_batches):
      return [1., 2., 2., 3.][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testFalsePositivesAtThresholds(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.false_positives_at_thresholds(labels, predictions, [.5])

    def _expected_fn(num_batches):
      return [[1.], [2.], [2.], [3.]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testTruePositives(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.true_positives(labels, predictions)

    def _expected_fn(num_batches):
      return [1., 2., 3., 3.][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testTruePositivesAtThresholds(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.true_positives_at_thresholds(labels, predictions, [.5])

    def _expected_fn(num_batches):
      return [[1.], [2.], [3.], [3.]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testPrecision(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.precision(labels, predictions)

    def _expected_fn(num_batches):
      return [0.5, 0.5, 0.6, 0.5][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testPrecisionAtThreshold(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.precision_at_thresholds(labels, predictions, [0.5])

    def _expected_fn(num_batches):
      return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testRecall(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.recall(labels, predictions)

    def _expected_fn(num_batches):
      return [0.5, 2./3, 0.6, 0.5][num_batches - 1]

    self._test_metric(
        distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testRecallAtThreshold(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.recall_at_thresholds(labels, predictions, [0.5])

    def _expected_fn(num_batches):
      return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1]

    self._test_metric(
        distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testMeanSquaredError(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.mean_squared_error(labels, predictions)

    def _expected_fn(num_batches):
      return [0., 1./32, 0.208333, 0.15625][num_batches - 1]

    self._test_metric(
        distribution, _regression_dataset_fn, _metric_fn, _expected_fn)

  @combinations.generate(all_combinations())
  def testRootMeanSquaredError(self, distribution):
    def _metric_fn(x):
      labels = x["labels"]
      predictions = x["predictions"]
      return metrics.root_mean_squared_error(labels, predictions)

    def _expected_fn(num_batches):
      return [0., 0.176777, 0.456435, 0.395285][num_batches - 1]
|
||||
|
||||
self._test_metric(
|
||||
distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
|
||||
|
||||
@combinations.generate(all_combinations())
|
||||
def testSensitivityAtSpecificity(self, distribution):
|
||||
def _metric_fn(x):
|
||||
labels = x["labels"]
|
||||
predictions = x["predictions"]
|
||||
return metrics.sensitivity_at_specificity(labels, predictions, 0.8)
|
||||
|
||||
def _expected_fn(num_batches):
|
||||
return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
|
||||
|
||||
self._test_metric(
|
||||
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||
|
||||
@combinations.generate(all_combinations())
|
||||
def testSpecificityAtSensitivity(self, distribution):
|
||||
def _metric_fn(x):
|
||||
labels = x["labels"]
|
||||
predictions = x["predictions"]
|
||||
return metrics.specificity_at_sensitivity(labels, predictions, 0.95)
|
||||
|
||||
def _expected_fn(num_batches):
|
||||
return [0., 1./3, 0.5, 0.5][num_batches - 1]
|
||||
|
||||
self._test_metric(
|
||||
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
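For context on how the `_expected_fn` values above accumulate: each `tf.metrics.*` call returns a `(value, update_op)` pair, and the test harness runs `update_op` once per batch before reading `value`. A minimal single-device sketch of that contract (assuming TF 1.x graph mode; not part of this diff):

```python
# Streaming metric sketch: false negatives accumulate across batches.
import tensorflow as tf

labels = tf.placeholder(tf.bool, [None])
predictions = tf.placeholder(tf.bool, [None])
value, update_op = tf.metrics.false_negatives(labels, predictions)

with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric accumulator variables
  batches = [([True, False], [False, False]),  # adds 1 false negative
             ([True, True], [True, True])]     # adds 0 false negatives
  for y, p in batches:
    sess.run(update_op, {labels: y, predictions: p})
  print(sess.run(value))  # 1.0 -- counts accumulate, as in _expected_fn
```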
@ -31,6 +31,7 @@ from tensorflow.python.eager import tape
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import coordinator
from tensorflow.python.training import device_util
@ -343,6 +344,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
          **values.select_device_mirrored(d, kwargs))
    return values.regroup(updates, values.Mirrored)

  def read_var(self, tower_local_var):
    """Read the aggregate value of a tower-local variable."""
    if isinstance(tower_local_var, values.TowerLocalVariable):
      return math_ops.add_n(self.unwrap(tower_local_var))
    assert isinstance(tower_local_var, values.Mirrored)
    return array_ops.identity(tower_local_var.get())

  def _fetch(self, val, destination, fn):
    """Return a copy of `val` or `fn(val)` on `destination`."""
    if isinstance(val, values.TowerLocalVariable):
@ -102,6 +102,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
    with ops.device(self._device), distribute_lib.UpdateContext(self._device):
      return fn(*args, **kwargs)

  def read_var(self, tower_local_var):
    """Read the aggregate value of a tower-local variable."""
    return array_ops.identity(tower_local_var)

  def _fetch(self, val, destination, fn):
    """Return a copy of `val` or `fn(val)` on `destination`."""
    with ops.device(self._device):
@ -29,7 +29,9 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.linalg import linear_operator_diag
@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase):
    return False


class TestMoveDimension(test.TestCase):

  @test_util.run_in_graph_and_eager_modes()
  def test_move_dimension_static_shape(self):

    x = random_ops.random_normal(shape=[200, 30, 4, 1, 6])

    x_perm = distribution_util.move_dimension(x, 1, 1)
    self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6])

    x_perm = distribution_util.move_dimension(x, 0, 3)
    self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6])

    x_perm = distribution_util.move_dimension(x, 0, -2)
    self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6])

    x_perm = distribution_util.move_dimension(x, 4, 2)
    self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1])

  @test_util.run_in_graph_and_eager_modes()
  def test_move_dimension_dynamic_shape(self):

    x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
    x = array_ops.placeholder_with_default(input=x_, shape=None)

    x_perm = distribution_util.move_dimension(x, 1, 1)
    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
                        [200, 30, 4, 1, 6])

    x_perm = distribution_util.move_dimension(x, 0, 3)
    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
                        [30, 4, 1, 200, 6])

    x_perm = distribution_util.move_dimension(x, 0, -2)
    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
                        [30, 4, 1, 200, 6])

    x_perm = distribution_util.move_dimension(x, 4, 2)
    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
                        [200, 30, 6, 4, 1])

    x_perm = distribution_util.move_dimension(x, -1, 2)
    self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
                        [200, 30, 6, 4, 1])


if __name__ == "__main__":
  test.main()
@ -21,12 +21,19 @@ from __future__ import print_function
from tensorflow.contrib import linalg
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib

# The following two lines are redundant, in a sense. The first enables
# good coding practice *within* this file (`util.prefer_static_value`
# rather than `prefer_static_value`). The second ensures that users
# also get the core utils when they import this file.
from tensorflow.python.ops.distributions import util
from tensorflow.python.ops.distributions.util import *  # pylint: disable=wildcard-import


@ -484,3 +491,75 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution,
def static_value(x):
  """Returns the static value of a `Tensor` or `None`."""
  return tensor_util.constant_value(ops.convert_to_tensor(x))


def move_dimension(x, source_idx, dest_idx):
  """Move a single tensor dimension within its shape.

  This is a special case of `tf.transpose()`, which applies
  arbitrary permutations to tensor dimensions.

  Args:
    x: Tensor of rank `ndims`.
    source_idx: Integer index into `x.shape` (negative indexing is
      supported).
    dest_idx: Integer index into `x.shape` (negative indexing is
      supported).

  Returns:
    x_perm: Tensor of rank `ndims`, in which the dimension at original
      index `source_idx` has been moved to new index `dest_idx`, with
      all other dimensions retained in their original order.

  Example:

  ```python
  x = tf.placeholder(shape=[200, 30, 4, 1, 6])
  x_perm = _move_dimension(x, 1, 1)  # no-op
  x_perm = _move_dimension(x, 0, 3)  # result shape [30, 4, 1, 200, 6]
  x_perm = _move_dimension(x, 0, -2)  # equivalent to previous
  x_perm = _move_dimension(x, 4, 2)  # result shape [200, 30, 6, 4, 1]
  ```
  """
  ndims = util.prefer_static_rank(x)
  if isinstance(source_idx, int):
    dtype = dtypes.int32
  else:
    dtype = dtypes.as_dtype(source_idx.dtype)

  # Handle negative indexing. Since ndims might be dynamic, this makes
  # source_idx and dest_idx also possibly dynamic.
  if source_idx < 0:
    source_idx = ndims + source_idx
  if dest_idx < 0:
    dest_idx = ndims + dest_idx

  # Construct the appropriate permutation of dimensions, depending
  # whether the source is before or after the destination.
  def move_left_permutation():
    return util.prefer_static_value(
        array_ops.concat([
            math_ops.range(0, dest_idx, dtype=dtype),
            [source_idx],
            math_ops.range(dest_idx, source_idx, dtype=dtype),
            math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0))

  def move_right_permutation():
    return util.prefer_static_value(
        array_ops.concat([
            math_ops.range(0, source_idx, dtype=dtype),
            math_ops.range(source_idx+1, dest_idx+1, dtype=dtype),
            [source_idx],
            math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0))

  def x_permuted():
    return array_ops.transpose(
        x, perm=smart_cond.smart_cond(source_idx < dest_idx,
                                      move_right_permutation,
                                      move_left_permutation))

  # One final conditional to handle the special case where source
  # and destination indices are equal.
  return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx),
                               lambda: x,
                               x_permuted)
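A quick equivalence check for the permutation logic above (a standalone sketch, assuming TF 1.x eager mode; the `distribution_util` import path is the contrib module this diff modifies):

```python
# move_dimension(x, 0, 3) should match tf.transpose with perm [1, 2, 3, 0, 4].
import tensorflow as tf
tf.enable_eager_execution()

from tensorflow.contrib.distributions.python.ops import distribution_util

x = tf.random_normal(shape=[2, 3, 4, 5, 6])
moved = distribution_util.move_dimension(x, 0, 3)
expected = tf.transpose(x, perm=[1, 2, 3, 0, 4])
assert moved.shape.as_list() == [3, 4, 5, 2, 6]
assert bool(tf.reduce_all(tf.equal(moved, expected)))
```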
@ -529,6 +529,7 @@ def multi_label_head(n_classes,
  applications, the shape is `[batch_size, n_classes]`.

  Labels can be:

  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
  * An integer `SparseTensor` of class indices. The `dense_shape` must be
    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
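As a concrete illustration of the two equivalent label encodings described above (a sketch, not from the diff; shapes chosen arbitrarily):

```python
# Labels for 2 examples with n_classes=3: example 0 has classes {0, 2},
# example 1 has class {1}.
import tensorflow as tf

multi_hot = tf.constant([[1, 0, 1],
                         [0, 1, 0]], dtype=tf.int64)

# The same labels as a SparseTensor of class indices; the trailing
# dense_shape dimension is the max number of labels per example.
sparse = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                         values=[0, 2, 1],
                         dense_shape=[2, 2])
```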
@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops

@ -279,13 +278,27 @@ def _assert_increasing(t):
  return ops.control_dependencies([assert_increasing])


def _check_input_types(t, y0):
def _check_input_types(y0, t, dt=None):
  if not (y0.dtype.is_floating or y0.dtype.is_complex):
    raise TypeError('`y0` must have a floating point or complex floating '
                    'point dtype')
  if not t.dtype.is_floating:
    raise TypeError('`t` must have a floating point dtype')

  if dt is not None and not dt.dtype.is_floating:
    raise TypeError('`dt` must have a floating point dtype')


def _check_input_sizes(t, dt):
  if len(t.get_shape().as_list()) > 1:
    raise ValueError('t must be a 1D tensor')

  if len(dt.get_shape().as_list()) > 1:
    raise ValueError('dt must be a 1D tensor')

  if t.get_shape()[0] != dt.get_shape()[0] + 1:
    raise ValueError('t and dt have incompatible lengths, must be N and N-1')


def _dopri5(func,
            y0,
@ -510,7 +523,7 @@ def odeint(func,
  # avoiding the need to pack/unpack in user functions.
  y0 = ops.convert_to_tensor(y0, name='y0')
  t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
  _check_input_types(t, y0)
  _check_input_types(y0, t)

  error_dtype = abs(y0).dtype
  rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
@ -530,24 +543,74 @@ def odeint(func,
class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
  """Base class for fixed-grid ODE integrators."""

  def integrate(self, evol_func, y0, time_grid):
    time_delta_grid = time_grid[1:] - time_grid[:-1]
  def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals):
    """Returns integrated values of the differential equation on `time_grid`.

    scan_func = self._make_scan_func(evol_func)
    Numerically integrates the differential equation defined via the time
    derivative evaluator `evol_func`, using the fixed time steps specified in
    `dt_grid`.

    y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid),
                                 y0)
    return array_ops.concat([[y0], y_grid], axis=0)
    Args:
      evol_func: Callable, evaluates the time derivative of y at a given time.
      y0: N-D Tensor holding the initial values of the solution.
      time_grid: 1-D Tensor holding the time points at which the solution
        will be recorded, must have a floating dtype.
      dt_grid: 1-D Tensor holding the fixed time steps to be used on the
        `time_grid` intervals. Must have a floating dtype and one less
        element than `time_grid`.
      steps_on_intervals: 1-D Tensor of integer dtype, must have the same size
        as `dt_grid`. Specifies the number of steps needed for every interval.
        Assumes steps_on_intervals * dt_grid == time intervals.

  def _make_scan_func(self, evol_func):
    Returns:
      (N+1)-D tensor, where the first dimension corresponds to different
      time points. Contains the solved value of y for each desired time point
      in `t`, with the initial value `y0` being the first element along the
      first dimension.
    """

    def scan_func(y, t_and_dt):
      t, dt = t_and_dt
    iteration_func = self._make_iteration_func(evol_func, dt_grid)
    integrate_interval = self._make_interval_integrator(iteration_func,
                                                        steps_on_intervals)

    num_times = array_ops.size(time_grid)
    current_time = time_grid[0]
    solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times)
    solution_array = solution_array.write(0, y0)

    solution_array, _, _, _ = control_flow_ops.while_loop(
        lambda _, __, ___, i: i < num_times,
        integrate_interval,
        (solution_array, y0, current_time, 1)
    )
    solution_array = solution_array.stack()
    solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape()))
    return solution_array

  def _make_iteration_func(self, evol_func, dt_grid):
    """Returns a function that builds operations of a single time step."""

    def iteration_func(y, t, dt_step, interval_step):
      """Performs a single time step advance."""
      dt = dt_grid[interval_step - 1]
      dy = self._step_func(evol_func, t, dt, y)
      dy = math_ops.cast(dy, dtype=y.dtype)
      return y + dy
      return y + dy, t + dt, dt_step + 1, interval_step

    return scan_func
    return iteration_func

  def _make_interval_integrator(self, iteration_func, interval_sizes):
    """Returns a function that builds operations for interval integration."""

    def integrate_interval(solution_array, y, t, interval_num):
      """Integrates y with fixed time step on interval `interval_num`."""
      y, t, _, _ = control_flow_ops.while_loop(
          lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1],
          iteration_func,
          (y, t, 0, interval_num)
      )
      return solution_array.write(interval_num, y), y, t, interval_num + 1

    return integrate_interval

  @abc.abstractmethod
  def _step_func(self, evol_func, t, dt, y):
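The `TensorArray`/`while_loop` plumbing above is graph-mode machinery for a simple nested loop. The same fixed-grid scheme in plain Python (a sketch, not from the diff; forward-Euler steps stand in for the midpoint/RK4 `_step_func`s the diff actually implements):

```python
# Advance y across each [t[i], t[i+1]] interval with steps_on_intervals[i]
# steps of size dt_grid[i], recording y at every grid point.
def integrate_fixed_grid(step_func, y0, time_grid, dt_grid, steps_on_intervals):
  solution = [y0]
  y, t = y0, time_grid[0]
  for dt, num_steps in zip(dt_grid, steps_on_intervals):
    for _ in range(num_steps):
      y = y + step_func(t, dt, y)
      t = t + dt
    solution.append(y)
  return solution

# dy/dt = -y with Euler steps; 2 * 0.5 and 4 * 0.25 exactly tile the
# intervals [0, 1] and [1, 2].
euler = lambda t, dt, y: dt * (-y)
print(integrate_fixed_grid(euler, 1.0, [0., 1., 2.], [0.5, 0.25], [2, 4]))
```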
@ -555,6 +618,7 @@ class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):


class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
  """Fixed grid integrator implementing midpoint scheme."""

  def _step_func(self, evol_func, t, dt, y):
    dt_cast = math_ops.cast(dt, y.dtype)
@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator):


class _RK4FixedGridIntegrator(_FixedGridIntegrator):
  """Fixed grid integrator implementing RK4 scheme."""

  def _step_func(self, evol_func, t, dt, y):
    k1 = evol_func(y, t)
@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator):
    return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6)


def odeint_fixed(func, y0, t, method='rk4', name=None):
def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None):
  """ODE integration on a fixed grid (with no step size control).

  Useful in certain scenarios to avoid the overhead of adaptive step size
@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
      `y`. The initial time point should be the first element of this sequence,
      and each time must be larger than the previous time. May have any floating
      point dtype.
    dt: 0-D or 1-D Tensor providing the suggested time step to use on the
      integration intervals in `t`. A 1-D Tensor must provide one value per
      interval, i.e. have one less element than `t`. A 0-D Tensor is
      interpreted as the same time step suggestion for all intervals. If
      passed None, the time step defaults to `t[1:] - t[:-1]`. The actual
      step size is obtained by ensuring an integer number of steps per
      interval, potentially reducing the time step.
    method: One of 'midpoint' or 'rk4'.
    name: Optional name for the resulting operation.

@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
  Raises:
    ValueError: Upon caller errors.
  """
  with ops.name_scope(name, 'odeint_fixed', [y0, t]):
  with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]):
    t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
    y0 = ops.convert_to_tensor(y0, name='y0')
    _check_input_types(t, y0)

    intervals = t[1:] - t[:-1]
    if dt is None:
      dt = intervals
    dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt')

    steps_on_intervals = math_ops.ceil(intervals / dt)
    dt = intervals / steps_on_intervals
    steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32)

    _check_input_types(y0, t, dt)
    _check_input_sizes(t, dt)

    with _assert_increasing(t):
      with ops.name_scope(method):
        if method == 'midpoint':
          return _MidpointFixedGridIntegrator().integrate(func, y0, t)
          return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt,
                                                          steps_on_intervals)
        elif method == 'rk4':
          return _RK4FixedGridIntegrator().integrate(func, y0, t)
          return _RK4FixedGridIntegrator().integrate(func, y0, t, dt,
                                                     steps_on_intervals)
        else:
          raise ValueError('method not supported: {!s}'.format(method))
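The three lines that normalize `dt` are easy to check by hand (plain-Python sketch of the same arithmetic):

```python
# A suggested dt is coerced to an integer number of steps per interval,
# never increasing the step size.
import math

interval = 1.0    # t[i+1] - t[i]
dt_suggest = 0.3  # user-provided suggestion

steps = math.ceil(interval / dt_suggest)  # ceil(3.33...) == 4
dt_actual = interval / steps              # 0.25 <= dt_suggest
assert steps * dt_actual == interval
```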
@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase):

class OdeIntFixedTest(test.TestCase):

  def _test_integrate_sine(self, method):
  def _test_integrate_sine(self, method, t, dt=None):

    def evol_func(y, t):
      del t
      return array_ops.stack([y[1], -y[0]])

    y0 = [0., 1.]
    time_grid = np.linspace(0., 10., 200)
    y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
    y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)

    with self.test_session() as sess:
      y_grid_array = sess.run(y_grid)

    np.testing.assert_allclose(
        y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2)
        y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2)

  def _test_integrate_gaussian(self, method):
  def _test_integrate_gaussian(self, method, t, dt=None):

    def evol_func(y, t):
      return -math_ops.cast(t, dtype=y.dtype) * y[0]

    y0 = [1.]
    time_grid = np.linspace(0., 2., 100)
    y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
    y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)

    with self.test_session() as sess:
      y_grid_array = sess.run(y_grid)

    np.testing.assert_allclose(
        y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2)
        y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2)

  def _test_integrate_sine_all(self, method):
    uniform_time_grid = np.linspace(0., 10., 200)
    non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0])
    uniform_dt = 0.02
    non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03])
    self._test_integrate_sine(method, uniform_time_grid)
    self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt)
    self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt)

  def _test_integrate_gaussian_all(self, method):
    uniform_time_grid = np.linspace(0., 2., 100)
    non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0])
    uniform_dt = 0.01
    non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03])
    self._test_integrate_gaussian(method, uniform_time_grid)
    self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt)
    self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt)

  def _test_everything(self, method):
    self._test_integrate_sine(method)
    self._test_integrate_gaussian(method)
    self._test_integrate_sine_all(method)
    self._test_integrate_gaussian_all(method)

  def test_midpoint(self):
    self._test_everything('midpoint')
@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase):
  def test_rk4(self):
    self._test_everything('rk4')

  def test_dt_size_exceptions(self):
    times = np.linspace(0., 2., 100)
    dt = np.ones(99) * 0.01
    dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03])
    dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0)
    times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0)
    with self.assertRaises(ValueError):
      self._test_integrate_gaussian('midpoint', times, dt_wrong_length)

    with self.assertRaises(ValueError):
      self._test_integrate_gaussian('midpoint', times, dt_wrong_dim)

    with self.assertRaises(ValueError):
      self._test_integrate_gaussian('midpoint', times_wrong_dim, dt)


if __name__ == '__main__':
  test.main()
@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_

// DO NOT EDIT MANUALLY: This file is automatically generated by
// `schema_builtin_ops_header_generator.py`.
// `schema/builtin_ops_header/generator.cc`.

#ifdef __cplusplus
extern "C" {
@ -474,8 +474,9 @@ cc_test(
)

cc_test(
    name = "resize_bilinear_float_test",
    srcs = ["resize_bilinear_float_test.cc"],
    name = "resize_bilinear_test",
    srcs = ["resize_bilinear_test.cc"],
    tags = ["tflite_not_portable"],
    deps = [
        ":optimized_base",
        ":reference_base",
@ -5722,6 +5722,46 @@ inline void ResizeBilinearGeneric(const float* input_data,
  }
}

template <typename T>
inline void ResizeBilinearGenericSmallChannel(
    const T* input_data, const Dims<4>& input_dims, T* output_data,
    const Dims<4>& output_dims, int32 batches, int32 input_height,
    int32 input_width, int32 depth, int32 output_height, int32 output_width,
    float height_scale, float width_scale) {
  memset(output_data, 0,
         batches * output_height * output_width * depth * sizeof(T));

  T* output_ptr = &output_data[0];
  for (int b = 0; b < batches; ++b) {
    for (int y = 0; y < output_height; ++y) {
      float input_y = y * height_scale;
      int32 y0 = static_cast<int32>(std::floor(input_y));
      int32 y1 = std::min(y0 + 1, input_height - 1);
      for (int x = 0; x < output_width; ++x) {
        float input_x = x * width_scale;
        int32 x0 = static_cast<int32>(input_x);
        int32 x1 = std::min(x0 + 1, input_width - 1);

        int32 input_offset[4] = {
            Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b),
            Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)};
        float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
                          (1 - (input_y - y0)) * (input_x - x0),
                          (input_y - y0) * (1 - (input_x - x0)),
                          (input_y - y0) * (input_x - x0)};

        for (int d = 0; d < depth; d++) {
          const T* input_ptr = &input_data[d];
          *output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
                                         input_ptr[input_offset[1]] * scale[1] +
                                         input_ptr[input_offset[2]] * scale[2] +
                                         input_ptr[input_offset[3]] * scale[3]);
        }
      }
    }
  }
}

inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, float* output_data,
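The `scale[4]` array above holds the standard bilinear weights; a quick numeric check (plain-Python sketch, not from the diff):

```python
# Bilinear weights for a sample point with fractional offsets
# fy = input_y - y0 and fx = input_x - x0.
def bilinear_weights(fy, fx):
  return [(1 - fy) * (1 - fx),  # (x0, y0)
          (1 - fy) * fx,        # (x1, y0)
          fy * (1 - fx),        # (x0, y1)
          fy * fx]              # (x1, y1)

w = bilinear_weights(0.25, 0.5)
print(w)                          # [0.375, 0.375, 0.125, 0.125]
assert abs(sum(w) - 1.0) < 1e-12  # the weights always sum to 1
```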
@ -5762,6 +5802,41 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
  }
}

// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
// or int16 arithmetic.
inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, uint8* output_data,
                           const Dims<4>& output_dims, bool align_corners) {
  gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
  int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
  int32 input_height = ArraySize(input_dims, 2);
  int32 input_width = ArraySize(input_dims, 1);
  int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);

  TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
  TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
  TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
  TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
  int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
  int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];

  float height_scale =
      (align_corners && output_height > 1)
          ? (static_cast<float>(input_height - 1) / (output_height - 1))
          : (static_cast<float>(input_height) / output_height);

  float width_scale =
      (align_corners && output_width > 1)
          ? (static_cast<float>(input_width - 1) / (output_width - 1))
          : (static_cast<float>(input_width) / output_width);

  ResizeBilinearGenericSmallChannel<uint8>(
      input_data, input_dims, output_data, output_dims, batches, input_height,
      input_width, depth, output_height, output_width, height_scale,
      width_scale);
}

// legacy, for compatibility with old checked-in code
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
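The two scale branches differ only in how the grid endpoints are treated; a numeric sketch of the difference (plain Python, values illustrative):

```python
# Scale factors for resizing 4 input rows to 8 output rows.
in_h, out_h = 4, 8

align_corners = (in_h - 1) / (out_h - 1)  # 3/7: first/last rows coincide
default = in_h / out_h                    # 0.5: treats pixels as cells

# Output row y samples input row y * scale; with align_corners the last
# output row (y = 7) lands on the last input row: 7 * 3/7 == 3 (up to
# float rounding).
print(7 * align_corners, 7 * default)  # ~3.0 vs 3.5
```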
@ -5771,6 +5846,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
                 output_data, output_dims, /*align_corners=*/false);
}

// legacy, for compatibility with old checked-in code
inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, uint8* output_data,
                           const Dims<4>& output_dims) {
  ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
                 output_data, output_dims, /*align_corners=*/false);
}

template <typename T>
inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
                           const int32* block_shape_data,
@ -3202,9 +3202,10 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims,
  }
}

inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
template <typename T>
inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, float* output_data,
                           const Dims<4>& output_size_dims, T* output_data,
                           const Dims<4>& output_dims, bool align_corners) {
  int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
  int32 input_height = ArraySize(input_dims, 2);
@ -3236,15 +3237,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
      int32 x0 = static_cast<int32>(std::floor(input_x));
      int32 x1 = std::min(x0 + 1, input_width - 1);
      for (int c = 0; c < depth; ++c) {
        float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] *
                                  (1 - (input_y - y0)) *
                                  (1 - (input_x - x0)) +
                              input_data[Offset(input_dims, c, x0, y1, b)] *
                                  (input_y - y0) * (1 - (input_x - x0)) +
                              input_data[Offset(input_dims, c, x1, y0, b)] *
                                  (1 - (input_y - y0)) * (input_x - x0) +
                              input_data[Offset(input_dims, c, x1, y1, b)] *
                                  (input_y - y0) * (input_x - x0);
        T interpolation =
            static_cast<T>(input_data[Offset(input_dims, c, x0, y0, b)] *
                               (1 - (input_y - y0)) * (1 - (input_x - x0)) +
                           input_data[Offset(input_dims, c, x0, y1, b)] *
                               (input_y - y0) * (1 - (input_x - x0)) +
                           input_data[Offset(input_dims, c, x1, y0, b)] *
                               (1 - (input_y - y0)) * (input_x - x0) +
                           input_data[Offset(input_dims, c, x1, y1, b)] *
                               (input_y - y0) * (input_x - x0));
        output_data[Offset(output_dims, c, x, y, b)] = interpolation;
      }
    }
@ -3257,8 +3258,18 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, float* output_data,
                           const Dims<4>& output_dims) {
  ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
                 output_data, output_dims, /*align_corners=*/false);
  ResizeBilinear<float>(input_data, input_dims, output_size_data,
                        output_size_dims, output_data, output_dims,
                        /*align_corners=*/false);
}

inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
                           const int32* output_size_data,
                           const Dims<4>& output_size_dims, uint8* output_data,
                           const Dims<4>& output_dims) {
  ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
                        output_size_dims, output_data, output_dims,
                        /*align_corners=*/false);
}

template <typename T>
@ -24,9 +24,10 @@ limitations under the License.

namespace tflite {
namespace {
template <typename T>
void TestOneResizeBilinear(int batch, int depth, int input_width,
                           int input_height, int output_width,
                           int output_height) {
                           int output_height, float error_threshold) {
  Dims<4> input_dims_inference =
      MakeDimsForInference(depth, input_width, input_height, batch);
  Dims<4> output_dims_inference =
@ -36,14 +37,15 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
  const int output_buffer_size =
      RequiredBufferSizeForDims(output_dims_inference);

  std::vector<float> input_data(input_buffer_size, 0);
  std::vector<float> reference_output_data(output_buffer_size, 0);
  std::vector<T> input_data(input_buffer_size, 0);
  std::vector<T> reference_output_data(output_buffer_size, 0);
  // Initialize the output data with something other than zero, so we can catch
  // issue with kernels failing to initialize the output.
  std::vector<float> output_data(output_buffer_size, 3.1415);
  std::vector<T> output_data(output_buffer_size, 3);

  const float input_amplitude = 1.f;
  FillRandom(&input_data, -input_amplitude, input_amplitude);
  const T min_amplitude = static_cast<T>(0);
  const T max_amplitude = static_cast<T>(255);
  FillRandom(&input_data, min_amplitude, max_amplitude);

  Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1);
  std::vector<int32> output_size_data = {output_height, output_width};
@ -58,14 +60,46 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
  double sum_diff = 0;
  float max_abs_val = 0;
  for (int i = 0; i < output_buffer_size; i++) {
    sum_diff += std::abs(output_data[i] - reference_output_data[i]);
    max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i]));
    sum_diff += std::abs(static_cast<float>(output_data[i]) -
                         static_cast<float>(reference_output_data[i]));
    max_abs_val = std::max(
        max_abs_val, std::abs(static_cast<float>(reference_output_data[i])));
  }

  if (sum_diff != 0.f) {
    const float mean_diff = static_cast<float>(sum_diff / output_buffer_size);
    const float relative_error = std::abs(mean_diff) / max_abs_val;
    ASSERT_LT(relative_error, 1e-5f);
    ASSERT_LT(relative_error, error_threshold);
  }
}

TEST(ResizeBilinear, TestResizeBilinear8Bit) {
  const int kTestsToRun = 100 * 1000;
  for (int i = 0; i < kTestsToRun; i++) {
    const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
    const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
    const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
    const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
    const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
    const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);

    TestOneResizeBilinear<uint8>(batch, depth, input_width, input_height,
                                 output_width, output_height, 0.025);
  }
}

TEST(ResizeBilinear2x2, TestResizeBilinear8Bit) {
  const int kTestsToRun = 100 * 1000;
  for (int i = 0; i < kTestsToRun; i++) {
    const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
    const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
    const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
    const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
    const int output_width = input_width * 2;
    const int output_height = input_height * 2;

    TestOneResizeBilinear<uint8>(batch, depth, input_width, input_height,
                                 output_width, output_height, 1e-5);
  }
}

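The looser 0.025 threshold for arbitrary-scale uint8 resizes (versus 1e-5 for exact 2x upscaling) presumably absorbs integer rounding in the quantized path. The metric itself is the mean absolute difference normalized by the reference's peak magnitude, as in this plain-Python sketch of the loop above:

```python
# Relative error as computed by TestOneResizeBilinear.
def relative_error(output, reference):
  sum_diff = sum(abs(o - r) for o, r in zip(output, reference))
  max_abs_val = max(abs(r) for r in reference)
  mean_diff = sum_diff / len(output)
  return abs(mean_diff) / max_abs_val

print(relative_error([10, 12, 14], [10, 13, 14]))  # ~0.0238, under 0.025
```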
@ -79,8 +113,8 @@ TEST(ResizeBilinear, TestResizeBilinear) {
    const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
    const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);

    TestOneResizeBilinear(batch, depth, input_width, input_height, output_width,
                          output_height);
    TestOneResizeBilinear<float>(batch, depth, input_width, input_height,
                                 output_width, output_height, 1e-5);
  }
}

@ -94,8 +128,8 @@ TEST(ResizeBilinear2x2, TestResizeBilinear) {
    const int output_width = input_width * 2;
    const int output_height = input_height * 2;

    TestOneResizeBilinear(batch, depth, input_width, input_height, output_width,
                          output_height);
    TestOneResizeBilinear<float>(batch, depth, input_width, input_height,
                                 output_width, output_height, 1e-5);
  }
}
}  // namespace
@ -121,6 +121,10 @@ class RuntimeShape {
    }
  }

  inline void BuildFrom(const std::initializer_list<int> init_list) {
    BuildFrom<const std::initializer_list<int>>(init_list);
  }

  // Returns the total count of elements, that is the size when flattened into a
  // vector.
  inline int FlatSize() const {
@ -61,12 +61,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
  TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);

  // TODO(ahentz): Our current implementations only support float32.
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
  // ResizeBilinear creates a float tensor even when the input is made of
  // integers.
  output->type = kTfLiteFloat32;
  output->type = input->type;

  if (!IsConstantTensor(size)) {
    SetTensorToDynamic(output);
@ -90,17 +88,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  }

  if (output->type == kTfLiteFloat32) {
#define TF_LITE_RESIZE_BILINEAR(type)                                       \
  type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input),   \
                       GetTensorData<int32>(size), GetTensorDims(size),     \
                       GetTensorData<float>(output), GetTensorDims(output), \
#define TF_LITE_RESIZE_BILINEAR(type, datatype)                              \
  type::ResizeBilinear(GetTensorData<datatype>(input), GetTensorDims(input), \
                       GetTensorData<int32>(size), GetTensorDims(size),      \
                       GetTensorData<datatype>(output),                      \
                       GetTensorDims(output),                                \
                       params->align_corners)

    if (kernel_type == kReference) {
      TF_LITE_RESIZE_BILINEAR(reference_ops);
      TF_LITE_RESIZE_BILINEAR(reference_ops, float);
    }
    if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
      TF_LITE_RESIZE_BILINEAR(optimized_ops);
      TF_LITE_RESIZE_BILINEAR(optimized_ops, float);
    }
  } else if (output->type == kTfLiteUInt8) {
    if (kernel_type == kReference) {
      TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t);
    }
    if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
      TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t);
    }
#undef TF_LITE_RESIZE_BILINEAR
  } else {
|
namespace {

using ::testing::ElementsAreArray;
using uint8 = std::uint8_t;

class ResizeBilinearOpModel : public SingleOpModel {
 public:
@ -34,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel {
    } else {
      size_ = AddInput({TensorType_INT32, {2}});
    }
    output_ = AddOutput(TensorType_FLOAT32);  // Always float.
    output_ = AddOutput(input.type);
    SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
                 BuiltinOptions_ResizeBilinearOptions,
                 CreateResizeBilinearOptions(builder_).Union());
@ -45,12 +46,16 @@ class ResizeBilinearOpModel : public SingleOpModel {
    }
  }

  void SetInput(std::initializer_list<float> data) {
  template <typename T>
  void SetInput(std::initializer_list<T> data) {
    PopulateTensor(input_, data);
  }
  void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }

  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
  template <typename T>
  std::vector<T> GetOutput() {
    return ExtractVector<T>(output_);
  }

 private:
  int input_;
@ -60,60 +65,121 @@ class ResizeBilinearOpModel : public SingleOpModel {

TEST(ResizeBilinearOpTest, HorizontalResize) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}});
  m.SetInput({3, 6});
  m.SetInput<float>({3, 6});
  m.SetSize({1, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 5, 6})));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
  const_m.SetInput({3, 6});
  const_m.SetInput<float>({3, 6});
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
  EXPECT_THAT(const_m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}

TEST(ResizeBilinearOpTest, HorizontalResize8Bit) {
  ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}});
  m.SetInput<uint8>({3, 6});
  m.SetSize({1, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<uint8>(),
              ElementsAreArray(ArrayFloatNear({3, 5, 6})));

  ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3});
  const_m.SetInput<uint8>({3, 6});
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<uint8>(),
              ElementsAreArray(ArrayFloatNear({3, 5, 6})));
}

TEST(ResizeBilinearOpTest, VerticalResize) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}});
  m.SetInput({3, 9});
  m.SetInput<float>({3, 9});
  m.SetSize({3, 1});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
  EXPECT_THAT(m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 7, 9})));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
  const_m.SetInput({3, 9});
  const_m.SetInput<float>({3, 9});
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
  EXPECT_THAT(const_m.GetOutput<float>(),
              ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}

TEST(ResizeBilinearOpTest, VerticalResize8Bit) {
  ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}});
  m.SetInput<uint8>({3, 9});
  m.SetSize({3, 1});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<uint8>(),
              ElementsAreArray(ArrayFloatNear({3, 7, 9})));

  ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1});
  const_m.SetInput<uint8>({3, 9});
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<uint8>(),
              ElementsAreArray(ArrayFloatNear({3, 7, 9})));
}

TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}});
  m.SetInput({
  m.SetInput<float>({
      3, 6,  //
      9, 12  //
  });
  m.SetSize({3, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,    //
      7, 9, 10,   //
      9, 11, 12,  //
  })));
  EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,    //
      7, 9, 10,   //
      9, 11, 12,  //
  })));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
  const_m.SetInput({
  const_m.SetInput<float>({
      3, 6,  //
      9, 12  //
  });
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,    //
      7, 9, 10,   //
      9, 11, 12,  //
  })));
  EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,    //
      7, 9, 10,   //
      9, 11, 12,  //
  })));
}

TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) {
  ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}});
  m.SetInput<uint8>({
      3, 6,  //
      9, 12  //
  });
  m.SetSize({3, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,    //
      7, 9, 10,   //
      9, 11, 12,  //
  })));

  ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3});
  const_m.SetInput<uint8>({
      3, 6,  //
      9, 12  //
  });
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,    //
      7, 9, 10,   //
      9, 11, 12,  //
  })));
}

TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}});
  m.SetInput({
  m.SetInput<float>({
      3, 6,   //
      9, 12,  //
      4, 10,  //
@ -121,60 +187,123 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
  });
  m.SetSize({3, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,     //
      7, 9, 10,    //
      9, 11, 12,   //
      4, 8, 10,    //
      8, 12, 14,   //
      10, 14, 16,  //
  })));
  EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,     //
      7, 9, 10,    //
      9, 11, 12,   //
      4, 8, 10,    //
      8, 12, 14,   //
      10, 14, 16,  //
  })));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3});
  const_m.SetInput({
  const_m.SetInput<float>({
      3, 6,   //
      9, 12,  //
      4, 10,  //
      10, 16  //
  });
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,     //
      7, 9, 10,    //
      9, 11, 12,   //
      4, 8, 10,    //
      8, 12, 14,   //
      10, 14, 16,  //
  })));
  EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,     //
      7, 9, 10,    //
      9, 11, 12,   //
      4, 8, 10,    //
      8, 12, 14,   //
      10, 14, 16,  //
  })));
}

TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}});
  m.SetInput({
  m.SetInput<float>({
      3, 4, 6, 10,    //
      9, 10, 12, 16,  //
  });
  m.SetSize({3, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
      3, 4, 5, 8, 6, 10,      //
      7, 8, 9, 12, 10, 14,    //
      9, 10, 11, 14, 12, 16,  //
  })));
  EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
      3, 4, 5, 8, 6, 10,      //
      7, 8, 9, 12, 10, 14,    //
      9, 10, 11, 14, 12, 16,  //
  })));

  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3});
  const_m.SetInput({
  const_m.SetInput<float>({
      3, 4, 6, 10,    //
      9, 10, 12, 16,  //
  });
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
      3, 4, 5, 8, 6, 10,      //
      7, 8, 9, 12, 10, 14,    //
      9, 10, 11, 14, 12, 16,  //
  })));
  EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
      3, 4, 5, 8, 6, 10,      //
      7, 8, 9, 12, 10, 14,    //
      9, 10, 11, 14, 12, 16,  //
  })));
}

TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) {
  ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}});
  m.SetInput<uint8>({
      3, 6,   //
      9, 12,  //
      4, 10,  //
      10, 16  //
  });
  m.SetSize({3, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,     //
      7, 9, 10,    //
      9, 11, 12,   //
      4, 8, 10,    //
      8, 12, 14,   //
      10, 13, 16,  //
  })));

  ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3});
  const_m.SetInput<uint8>({
      3, 6,   //
      9, 12,  //
      4, 10,  //
      10, 16  //
  });
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
      3, 5, 6,     //
      7, 9, 10,    //
      9, 11, 12,   //
      4, 8, 10,    //
      8, 12, 14,   //
      10, 13, 16,  //
  })));
}

TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) {
  ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}});
  m.SetInput<uint8>({
      3, 4, 6, 10,    //
      9, 10, 12, 16,  //
  });
  m.SetSize({3, 3});
  m.Invoke();
  EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
      3, 4, 5, 8, 6, 10,      //
      7, 8, 9, 12, 10, 14,    //
      9, 10, 11, 13, 12, 16,  //
  })));

  ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3});
  const_m.SetInput<uint8>({
      3, 4, 6, 10,    //
      9, 10, 12, 16,  //
  });
  const_m.Invoke();
  EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
      3, 4, 5, 8, 6, 10,      //
      7, 8, 9, 12, 10, 14,    //
      9, 10, 11, 13, 12, 16,  //
  })));
}
}  // namespace
}  // namespace tflite
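Note the 8-bit expectations above differ slightly from the float ones (`10, 13, 16` versus `10, 14, 16`): the uint8 kernel interpolates in float and then truncates through `static_cast`, so a value that lands just under 14.0 becomes 13. A quick sketch of that truncation (not from the diff):

```python
# static_cast<uint8>(x) truncates toward zero rather than rounding.
vals = [13.5, 13.9999999, 14.0]
print([int(v) for v in vals])  # [13, 13, 14]
```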
@ -322,12 +322,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,

  *builtin_data = nullptr;
  switch (op_type) {
    case BuiltinOperator_CALL:
      // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
      // ok for now, since there is no call implementation either.
      break;
    case BuiltinOperator_CUSTOM:
      break;
    case BuiltinOperator_CONV_2D: {
      TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
      if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
@ -343,22 +337,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
      *builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_TANH:
    case BuiltinOperator_LOGISTIC:
    case BuiltinOperator_RELU:
    case BuiltinOperator_RELU_N1_TO_1:
    case BuiltinOperator_RELU6:
    case BuiltinOperator_CONCAT_EMBEDDINGS:
    case BuiltinOperator_EXP:
    case BuiltinOperator_TOPK_V2:
    case BuiltinOperator_LOG_SOFTMAX:
    case BuiltinOperator_DEQUANTIZE:
    case BuiltinOperator_PRELU:
    case BuiltinOperator_FLOOR:
    case BuiltinOperator_NEG:
    case BuiltinOperator_SIN:
    case BuiltinOperator_LOG:
      break;
    case BuiltinOperator_CAST: {
      TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
      if (auto* schema_params = op->builtin_options_as_CastOptions()) {
@ -446,9 +424,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
      *builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_EMBEDDING_LOOKUP:
      // no-op.
      break;
    case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
      TfLiteEmbeddingLookupSparseParams* params =
          MallocPOD<TfLiteEmbeddingLookupSparseParams>();
@ -580,12 +555,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
      *builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_PAD: {
      break;
    }
    case BuiltinOperator_PADV2: {
      break;
    }
    case BuiltinOperator_RESHAPE: {
      auto* params = MallocPOD<TfLiteReshapeParams>();
      if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
@ -625,15 +594,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
      *builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_SPACE_TO_BATCH_ND: {
      break;
    }
    case BuiltinOperator_BATCH_TO_SPACE_ND: {
      break;
    }
    case BuiltinOperator_TRANSPOSE: {
      break;
    }
    case BuiltinOperator_MEAN: {
      auto* params = MallocPOD<TfLiteMeanParams>();
      if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
@ -673,10 +633,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
      *builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_MAXIMUM:
    case BuiltinOperator_MINIMUM: {
      break;
    }
    case BuiltinOperator_ARG_MAX: {
      auto* params = MallocPOD<TfLiteArgMaxParams>();
      if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
@ -686,18 +642,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
      *builtin_data = reinterpret_cast<void*>(params);
      break;
    }
    case BuiltinOperator_GREATER:
    case BuiltinOperator_GREATER_EQUAL:
    case BuiltinOperator_LESS:
    case BuiltinOperator_LESS_EQUAL:
    case BuiltinOperator_EQUAL:
    case BuiltinOperator_NOT_EQUAL:
    case BuiltinOperator_SELECT: {
      break;
    }
    case BuiltinOperator_SLICE: {
      break;
    }
    case BuiltinOperator_TRANSPOSE_CONV: {
      TfLiteTransposeConvParams* params =
          MallocPOD<TfLiteTransposeConvParams>();
@ -725,10 +669,46 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
error_reporter->Report("DELEGATE op shouldn't exist in model.");
|
||||
return kTfLiteError;
|
||||
}
|
||||
|
||||
// Below are the ops with no builtin_data strcture.
|
||||
case BuiltinOperator_BATCH_TO_SPACE_ND:
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
// ok for now, since there is no call implementation either.
case BuiltinOperator_CALL:
case BuiltinOperator_CONCAT_EMBEDDINGS:
case BuiltinOperator_CUSTOM:
case BuiltinOperator_DEQUANTIZE:
case BuiltinOperator_EMBEDDING_LOOKUP:
case BuiltinOperator_EQUAL:
case BuiltinOperator_EXP:
case BuiltinOperator_EXPAND_DIMS:
case BuiltinOperator_FLOOR:
case BuiltinOperator_GREATER:
case BuiltinOperator_GREATER_EQUAL:
case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL:
case BuiltinOperator_LOG:
case BuiltinOperator_LOGISTIC:
case BuiltinOperator_LOG_SOFTMAX:
case BuiltinOperator_MAXIMUM:
case BuiltinOperator_MINIMUM:
case BuiltinOperator_NEG:
case BuiltinOperator_NOT_EQUAL:
case BuiltinOperator_PAD:
case BuiltinOperator_PADV2:
case BuiltinOperator_PRELU:
case BuiltinOperator_RELU:
case BuiltinOperator_RELU6:
case BuiltinOperator_RELU_N1_TO_1:
case BuiltinOperator_SELECT:
case BuiltinOperator_SIN:
case BuiltinOperator_SLICE:
case BuiltinOperator_SPACE_TO_BATCH_ND:
case BuiltinOperator_TANH:
case BuiltinOperator_TILE:
case BuiltinOperator_TOPK_V2:
case BuiltinOperator_TRANSPOSE:
break;
}
}
return kTfLiteOk;
}

@ -234,7 +234,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
next_id++;
};

auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); };
auto add_add_params = [&add_scalar_int32](void* data) {
auto* builtin = reinterpret_cast<TfLiteAddParams*>(data);
add_scalar_int32(builtin->activation);
};

auto add_pooling_params = [&add_scalar_int32](void* data) {
auto builtin = reinterpret_cast<TfLitePoolParams*>(data);
@ -345,11 +348,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
switch (builtin) {
case tflite::BuiltinOperator_ADD:
nn_op_type = ANEURALNETWORKS_ADD;
add_add_params();
add_add_params(node.builtin_data);
break;
case tflite::BuiltinOperator_MUL:
nn_op_type = ANEURALNETWORKS_MUL;
add_add_params();
add_add_params(node.builtin_data);
break;
case tflite::BuiltinOperator_AVERAGE_POOL_2D:
add_pooling_params(node.builtin_data);

@ -2,9 +2,11 @@ package(default_visibility = ["//visibility:public"])

licenses(["notice"]) # Apache 2.0

load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")

common_copts = [
"-Wall",
]
] + tflite_copts()

cc_library(
name = "profiler",
@ -36,12 +38,14 @@ cc_library(
name = "time",
srcs = ["time.cc"],
hdrs = ["time.h"],
copts = common_copts,
)

cc_library(
name = "profile_summarizer",
srcs = ["profile_summarizer.cc"],
hdrs = ["profile_summarizer.h"],
copts = common_copts,
deps = [
":profiler",
"//tensorflow/contrib/lite:framework",
@ -53,6 +57,7 @@ cc_library(
cc_test(
name = "profile_summarizer_test",
srcs = ["profile_summarizer_test.cc"],
copts = common_copts,
deps = [
":profile_summarizer",
"//tensorflow/contrib/lite:framework",

@ -111,37 +111,35 @@ def tensor_name(x):
return x.name.split(":")[0]


def toco_convert(input_data,
input_tensors,
output_tensors,
inference_type=lite_constants.FLOAT,
inference_input_type=None,
input_format=lite_constants.TENSORFLOW_GRAPHDEF,
output_format=lite_constants.TFLITE,
quantized_input_stats=None,
default_ranges_stats=None,
drop_control_dependency=True,
reorder_across_fake_quant=False,
allow_custom_ops=False,
change_concat_input_ranges=False,
quantize_weights=False,
dump_graphviz_dir=None,
dump_graphviz_video=False):
"""Convert a model using TOCO from `input_format` to `output_format`.
def build_toco_convert_protos(input_tensors,
output_tensors,
inference_type=lite_constants.FLOAT,
inference_input_type=None,
input_format=lite_constants.TENSORFLOW_GRAPHDEF,
output_format=lite_constants.TFLITE,
quantized_input_stats=None,
default_ranges_stats=None,
drop_control_dependency=True,
reorder_across_fake_quant=False,
allow_custom_ops=False,
change_concat_input_ranges=False,
quantize_weights=False,
dump_graphviz_dir=None,
dump_graphviz_video=False):
"""Builds protocol buffers describing a conversion of a model using TOCO.

Typically this is to convert from TensorFlow GraphDef to TFLite, in which
case the default `input_format` and `output_format` are sufficient.

Args:
input_data: Input data (i.e. often `sess.graph_def`).
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
inference_type: Target data type of arrays in the output file. Currently
must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default FLOAT)
inference_input_type: Target data type of input arrays. Allows for a
different type for input arrays in the case of quantization. Currently
must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default `inference_type`)
input_format: Type of data to read. Currently must be
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
output_format: Output file format. Currently must be `{TFLITE,
@ -180,8 +178,8 @@ def toco_convert(input_data,
every graph transformation. (default False)

Returns:
The converted data. For example if TFLite was the destination, then
this will be a tflite flatbuffer in a bytes array.
model_flags, toco_flags: two protocol buffers describing the conversion
process.

Raises:
ValueError: If the input tensor type is unknown
@ -204,7 +202,6 @@ def toco_convert(input_data,
if dump_graphviz_dir:
toco.dump_graphviz_dir = dump_graphviz_dir
toco.dump_graphviz_include_video = dump_graphviz_video

model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
@ -216,7 +213,8 @@ def toco_convert(input_data,
tflite_input_type = lite_constants.INT64
elif input_tensor.dtype == _dtypes.uint8:
tflite_input_type = lite_constants.QUANTIZED_UINT8
# TODO(aselle): Insert strings when they are available
elif input_tensor.dtype == _dtypes.string:
tflite_input_type = lite_constants.STRING
else:
raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
input_tensor.dtype))
@ -233,10 +231,35 @@ def toco_convert(input_data,

for output_tensor in output_tensors:
model.output_arrays.append(tensor_name(output_tensor))
return model, toco

# TODO(aselle): Consider handling the case of allowing quantized
# inputs to be converted to float (via the toco.inference_input_type field).
data = toco_convert_protos(model.SerializeToString(),
toco.SerializeToString(),

def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
"""Convert a model using TOCO.

Typically this function is used to convert from TensorFlow GraphDef to TFLite.
Conversion can be customized by providing arguments that are forwarded to
`build_toco_convert_protos` (see documentation for details).

Args:
input_data: Input data (i.e. often `sess.graph_def`).
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
*args: See `build_toco_convert_protos`.
**kwargs: See `build_toco_convert_protos`.

Returns:
The converted data. For example if TFLite was the destination, then
this will be a tflite flatbuffer in a bytes array.

Raises:
Defined in `build_toco_convert_protos`.
"""
model_flags, toco_flags = build_toco_convert_protos(input_tensors,
output_tensors,
*args, **kwargs)
data = toco_convert_protos(model_flags.SerializeToString(),
toco_flags.SerializeToString(),
input_data.SerializeToString())
return data

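For orientation, the change above splits the old monolithic `toco_convert` into `build_toco_convert_protos`, which only assembles the `ModelFlags`/`TocoFlags` protos, and a thin `toco_convert` wrapper that serializes them and calls `toco_convert_protos`. A minimal usage sketch of the refactored API (the toy graph here is illustrative and not part of the change):

```python
import tensorflow as tf
from tensorflow.contrib.lite.python import convert

with tf.Session(graph=tf.Graph()) as sess:
    # A trivial float model: one placeholder input, one output.
    inp = tf.placeholder(tf.float32, shape=[1, 4], name="input")
    out = tf.identity(inp * 2.0, name="output")

    # New two-step path: the protos can now be built (and inspected)
    # separately from the conversion itself.
    model_flags, toco_flags = convert.build_toco_convert_protos([inp], [out])

    # The wrapper keeps the old one-call interface.
    tflite_model = convert.toco_convert(sess.graph_def, [inp], [out])

with open("/tmp/model.tflite", "wb") as f:
    f.write(tflite_model)
```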
@ -25,6 +25,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.

@@FLOAT
@@QUANTIZED_UINT8
@@STRING
@@TFLITE
@@GRAPHVIZ_DOT

@ -38,6 +39,7 @@ from six import PY3
from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
@ -65,10 +67,10 @@ class TocoConverter(object):
Attributes:

inference_type: Target data type of arrays in the output file. Currently
must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default FLOAT)
inference_input_type: Target data type of input arrays. Allows for a
different type for input arrays in the case of quantization. Currently
must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default `inference_type`)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: Dict of strings representing input tensor names

@ -116,7 +116,8 @@ def _convert_model(flags):
"tensors in order to map between names and "
"values.".format(",".join(input_arrays)))
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
if flags.default_ranges_min and flags.default_ranges_max:
if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
not None):
converter.default_ranges_stats = (flags.default_ranges_min,
flags.default_ranges_max)

@ -195,7 +196,7 @@ def _check_flags(flags, unparsed):
raise ValueError("--std_dev_values, --mean_values must have the same "
"number of items")

if bool(flags.default_ranges_min) != bool(flags.default_ranges_max):
if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
raise ValueError("--default_ranges_min and --default_ranges_max must be "
"used together")

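Both replaced checks above move from truthiness tests to explicit `None` comparisons. The distinction matters because a perfectly valid range bound of `0` (or `0.0`) is falsy in Python, so the old checks would silently ignore it. A small standalone sketch of the pitfall (the `Namespace` object is a stand-in for the real parsed flags):

```python
from argparse import Namespace

# Hypothetical stand-in for the parsed command-line flags.
flags = Namespace(default_ranges_min=0.0, default_ranges_max=6.0)

# Old check: bool(0.0) is False, so a valid minimum of 0.0 is dropped.
if flags.default_ranges_min and flags.default_ranges_max:
    print("old check: ranges applied")  # never reached

# New check: only skips the ranges when a flag was genuinely omitted.
if (flags.default_ranges_min is not None) and (
    flags.default_ranges_max is not None):
    print("new check: ranges applied")  # printed
```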
@ -233,12 +234,12 @@ def run_main(_):
parser.add_argument(
"--inference_type",
type=str.upper,
choices=["FLOAT", "QUANTIZED_UINT8"],
choices=["FLOAT", "QUANTIZED_UINT8", "STRING"],
help="Target data type of arrays in the output file.")
parser.add_argument(
"--inference_input_type",
type=str.upper,
choices=["FLOAT", "QUANTIZED_UINT8"],
choices=["FLOAT", "QUANTIZED_UINT8", "STRING"],
help=("Target data type of input arrays. Allows for a different type for "
"input arrays in the case of quantization."))

@ -39,7 +39,7 @@ limitations under the License.
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_

// DO NOT EDIT MANUALLY: This file is automatically generated by
// `schema_builtin_ops_header_generator.py`.
// `schema/builtin_ops_header/generator.cc`.

#ifdef __cplusplus
extern "C" {

@ -362,6 +362,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
break;

case OperatorType::kResizeBilinear:
case OperatorType::kSlice:
case OperatorType::kStridedSlice:
case OperatorType::kSqueeze:
case OperatorType::kTensorFlowReshape:

@ -45,12 +45,14 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kTensorFlowMinimum ||
type == OperatorType::kTensorFlowMaximum ||
type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
type == OperatorType::kLogSoftmax ||
type == OperatorType::kLogSoftmax || type == OperatorType::kSlice ||
type == OperatorType::kResizeBilinear ||
type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub ||
type == OperatorType::kSqueeze || type == OperatorType::kPad ||
type == OperatorType::kPadV2 ||
type == OperatorType::kTensorFlowReshape ||
type == OperatorType::kTanh || type == OperatorType::kMul ||
type == OperatorType::kSpaceToBatchND ||
type == OperatorType::kSpaceToDepth ||
type == OperatorType::kStridedSlice ||
type == OperatorType::kDepthToSpace ||

@ -920,7 +920,7 @@ void CheckEachArray(const Model& model) {
CHECK(array->buffer->type == array->data_type);
// The presence of a fixed buffer should imply the presence of a fixed
// shape.
CHECK(array->has_shape());
CHECK(array->has_shape()) << "Invalid array: " << array_entry.first;
// Constant buffers should have a valid shape.
for (int d : array->shape().dims()) {
CHECK_GE(d, 1);

@ -8,7 +8,7 @@ load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite"
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")

common_copts = ["-Wall"]
common_copts = ["-Wall"] + tflite_copts()

cc_binary(
name = "benchmark_model",
@ -16,14 +16,11 @@ cc_binary(
"benchmark_main.cc",
"logging.h",
],
copts = tflite_copts() + common_copts,
linkopts = select({
copts = common_copts,
linkopts = tflite_linkopts() + select({
"//tensorflow:android": [
"-pie",
"-landroid",
"-lm",
"-z defs",
"-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export
"-pie", # Android 5.0 and later supports only PIE
"-lm", # some builtin ops, e.g., tanh, need -lm
],
"//conditions:default": [],
}),

@ -53,7 +53,7 @@ tf_cc_binary(
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/platform/cloud:gcs_file_system",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -58,7 +58,7 @@ cc_library(
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
alwayslink = 1,
)
@ -69,7 +69,7 @@ cc_library(
hdrs = ["grpc_verbs_service_impl.h"],
deps = [
":verbs_service_proto_cc",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -879,6 +879,7 @@ cc_library(
hdrs = [
"util/stats_calculator.h",
],
copts = tf_copts(),
)

cc_library(

10
tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt
Normal file
@ -0,0 +1,10 @@
op {
graph_op_name: "BesselI0e"
summary: "Computes the Bessel i0e function of `x` element-wise."
description: <<END
Exponentially scaled modified Bessel function of order 0 defined as
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.

This function is faster and numerically stabler than `bessel_i0(x)`.
END
}
10
tensorflow/core/api_def/base_api/api_def_BesselI1e.pbtxt
Normal file
@ -0,0 +1,10 @@
op {
graph_op_name: "BesselI1e"
summary: "Computes the Bessel i1e function of `x` element-wise."
description: <<END
Exponentially scaled modified Bessel function of order 1 defined as
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.

This function is faster and numerically stabler than `bessel_i1(x)`.
END
}

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselI0e"
visibility: HIDDEN
}

@ -0,0 +1,4 @@
op {
graph_op_name: "BesselI1e"
visibility: HIDDEN
}

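The two new op definitions above compute exponentially scaled Bessel functions. As a sanity check on the stated identities, the scaled forms can be reproduced from SciPy's unscaled functions; this standalone NumPy/SciPy sketch is illustrative only and not part of the change:

```python
import numpy as np
from scipy.special import i0, i0e, i1, i1e

x = np.linspace(-5.0, 5.0, 11)

# bessel_i0e(x) = exp(-abs(x)) * bessel_i0(x), and likewise for order 1.
assert np.allclose(i0e(x), np.exp(-np.abs(x)) * i0(x))
assert np.allclose(i1e(x), np.exp(-np.abs(x)) * i1(x))

# The scaled form stays finite where the unscaled one overflows,
# which is why it is the numerically stabler choice.
print(i0(800.0))   # inf: exp(800) overflows float64
print(i0e(800.0))  # ~0.0141, well within range
```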
@ -66,6 +66,88 @@ int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
return 0;
}

// This function expects *handle to point to an existing tensor handle. The
// function will (maybe) update the *handle to be pointed to the newly copied
// tensor handle.
//
// The passed in *handle will be Unreffed if it is replaced.
Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
const Device* expected_device,
RunMetadata* run_metadata,
TensorHandle** handle) {
EagerContext* ctx = op->EagerContext();
Device* handle_device = nullptr;
TF_RETURN_IF_ERROR((*handle)->Device(&handle_device));
const Device* actual_device =
handle_device == nullptr ? ctx->HostCPU() : handle_device;

if (expected_device != actual_device) {
switch (ctx->GetDevicePlacementPolicy()) {
case DEVICE_PLACEMENT_SILENT_FOR_INT32:
// TODO(xpan): See if we could bubble python related error up
// to python level.
if ((*handle)->dtype == DT_INT32) {
// Note: enabling silent copies of int32 tensors to match behavior
// of graph mode.
break;
}
TF_FALLTHROUGH_INTENDED;
case DEVICE_PLACEMENT_EXPLICIT:
return errors::InvalidArgument(
"Tensors on conflicting devices:"
" cannot compute ",
op->Name(), " as input #", i, " was expected to be on ",
expected_device->name(), " but is actually on ",
actual_device->name(), " (operation running on ",
op->Device()->name(), ")",
" Tensors can be copied explicitly using .gpu() or .cpu() "
"methods,"
" or transparently copied by using tf.enable_eager_execution("
"device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
"between devices"
" may slow down your model");
case DEVICE_PLACEMENT_WARN:
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
<< " (operation running on " << op->Device()->name()
<< "). This triggers a copy which can be a performance "
"bottleneck.";
break;
case DEVICE_PLACEMENT_SILENT: // Do nothing.
break;
}
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
auto pre_time = Env::Default()->NowMicros();
TensorHandle* result_handle;
Status status = EagerCopyToDevice(
*handle, ctx, expected_device->name().c_str(), &result_handle);
if (run_metadata != nullptr) {
auto* step_stats = run_metadata->mutable_step_stats();
MaybeInitializeStepStats(step_stats, ctx);
// Record the sending on the source device for now.
int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
auto* node_stats = dev_stats->add_node_stats();
node_stats->set_node_name("_Send");
node_stats->set_all_start_micros(pre_time);
node_stats->set_op_end_rel_micros(Env::Default()->NowMicros() - pre_time);
}
if (!status.ok()) {
if (result_handle != nullptr) result_handle->Unref();
return errors::Internal("Failed copying input tensor from ",
actual_device->name(), " to ",
expected_device->name(), " in order to run ",
op->Name(), ": ", status.error_message());
}

(*handle)->Unref();
*handle = result_handle;
}
return Status::OK();
}

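The four policies the switch above dispatches on correspond one-to-one to the user-facing eager configuration of this era. A hedged TF 1.x sketch of what a conflicting-device input does under each policy (assumes a GPU is available; the contrib-era constant names are used):

```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe

# DEVICE_PLACEMENT_EXPLICIT:         mismatched inputs raise InvalidArgumentError.
# DEVICE_PLACEMENT_WARN:             inputs are copied, with a performance warning.
# DEVICE_PLACEMENT_SILENT:           inputs are copied silently.
# DEVICE_PLACEMENT_SILENT_FOR_INT32: only int32 inputs are copied silently.
tf.enable_eager_execution(device_policy=tfe.DEVICE_PLACEMENT_WARN)

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])  # lands on GPU:0 when one exists
with tf.device("/cpu:0"):
    y = tf.constant([[1.0, 0.0], [0.0, 1.0]])

# x and y live on different devices; under WARN the runtime logs the copy
# made by MaybeCopyInputToExpectedDevice and the matmul then succeeds.
print(tf.matmul(x, y))
```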
Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
EagerOperation* op, const OpKernel* kernel,
RunMetadata* run_metadata) {
@ -78,76 +160,9 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
for (int i = 0; i < op->Inputs().size(); ++i) {
const Device* expected_device =
memtypes[i] == HOST_MEMORY ? host_device : op_device;
TensorHandle* handle = op->Inputs()[i];
Device* handle_device = nullptr;
TF_RETURN_IF_ERROR(handle->Device(&handle_device));
const Device* actual_device =
handle_device == nullptr ? host_device : handle_device;
if (expected_device != actual_device) {
switch (ctx->GetDevicePlacementPolicy()) {
case DEVICE_PLACEMENT_SILENT_FOR_INT32:
// TODO(xpan): See if we could bubble python related error up
// to python level.
if (handle->dtype == DT_INT32) {
// Note: enabling silent copies of int32 tensors to match behavior
// of graph mode.
break;
}
TF_FALLTHROUGH_INTENDED;
case DEVICE_PLACEMENT_EXPLICIT:
return errors::InvalidArgument(
"Tensors on conflicting devices:"
" cannot compute ",
op->Name(), " as input #", i, " was expected to be on ",
expected_device->name(), " but is actually on ",
actual_device->name(), " (operation running on ",
op_device->name(), ")",
" Tensors can be copied explicitly using .gpu() or .cpu() "
"methods,"
" or transparently copied by using tf.enable_eager_execution("
"device_policy=tfe.DEVICE_PLACEMENT_SILENT). Copying tensors "
"between devices"
" may slow down your model");
case DEVICE_PLACEMENT_WARN:
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
<< " (operation running on " << op_device->name()
<< "). This triggers a copy which can be a performance "
"bottleneck.";
break;
case DEVICE_PLACEMENT_SILENT: // Do nothing.
break;
}
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
auto pre_time = Env::Default()->NowMicros();
TensorHandle* copied_tensor = nullptr;
Status status = EagerCopyToDevice(
handle, ctx, expected_device->name().c_str(), &copied_tensor);
if (run_metadata != nullptr) {
auto* step_stats = run_metadata->mutable_step_stats();
MaybeInitializeStepStats(step_stats, ctx);
// Record the sending on the source device for now.
int device_idx = StepStatsDeviceIndex(step_stats, ctx, handle_device);
auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
auto* node_stats = dev_stats->add_node_stats();
node_stats->set_node_name("_Send");
node_stats->set_all_start_micros(pre_time);
node_stats->set_op_end_rel_micros(Env::Default()->NowMicros() -
pre_time);
}
if (!status.ok()) {
if (copied_tensor != nullptr) copied_tensor->Unref();
return errors::Internal("Failed copying input tensor from ",
actual_device->name(), " to ",
expected_device->name(), " in order to run ",
op->Name(), ": ", status.error_message());
}
handle->Unref();
handle = copied_tensor;
(*op->MutableInputs())[i] = copied_tensor;
}
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
op, i, expected_device, run_metadata, &((*op->MutableInputs())[i])));
tensorflow::TensorHandle* handle = op->Inputs()[i];
if (handle->dtype != kernel->input_type(i)) {
return errors::InvalidArgument(
"cannot compute ", op->Name(), " as input #", i,
@ -192,8 +207,8 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
// Resource4> as the input params to the synthesized function.
//
// It populates `const_input_types`, `arg_input_types` and
// `op_input_to_func_input` based on the reordering results, that the caller can
// use them to build an XlaLaunch. On error, it returns NULL, and sets
// `op_input_to_func_input` based on the reordering results, that the caller
// can use them to build an XlaLaunch. On error, it returns NULL, and sets
// `status` accordingly.
const FunctionDef* OpToFunction(TFE_Op* op,
std::vector<TF_DataType>* const_input_types,
@ -221,8 +236,8 @@ const FunctionDef* OpToFunction(TFE_Op* op,
const std::unordered_set<string> const_inputs(
*XlaOpRegistry::CompileTimeConstantInputs(op->operation.Name()));

// First add place holders for the input args, so that we can refer to them by
// position in the next loop. Also tally up the resource inputs.
// First add place holders for the input args, so that we can refer to them
// by position in the next loop. Also tally up the resource inputs.
int num_resource_inputs = 0;
for (int i = 0; i < op_def.input_arg_size(); ++i) {
if (op_def.input_arg(i).type() == DT_RESOURCE) {
@ -336,8 +351,9 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
&op_input_to_func_input, status);
if (!status.ok()) return nullptr;
} else {
// TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
// functions, so we need to find another way to handle constant inputs.
// TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work
// for functions, so we need to find another way to handle constant
// inputs.
for (int i = const_input_types.size();
i < fdef->signature().input_arg_size(); ++i) {
VLOG(1) << "Adding Targs from input arg " << i;
@ -348,8 +364,9 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
DCHECK(fdef != nullptr);

// Copy inputs and their devices.
// Since input param reordering may have occurred between `op` and `launch_op`
// via `op_input_to_func_input`, adjust the actual inputs accordingly.
// Since input param reordering may have occurred between `op` and
// `launch_op` via `op_input_to_func_input`, adjust the actual inputs
// accordingly.
*launch_op->operation.MutableInputs() = op->operation.Inputs();
for (TensorHandle* h : launch_op->operation.Inputs()) {
h->Ref();
@ -545,24 +562,24 @@ Status EagerLocalExecute(EagerOperation* op,
Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
uint64 context_id, TensorHandle** retvals,
int* num_retvals) {
// All tensors must be on the same device.
// TODO(nareshmodi): handle silent copies
eager::EnqueueRequest request;
eager::EnqueueResponse response;

auto* remote_op = request.add_queue()->mutable_operation();

for (auto* input : op->Inputs()) {
for (int i = 0; i < op->Inputs().size(); i++) {
tensorflow::Device* input_device;
TF_RETURN_IF_ERROR(input->Device(&input_device));
TF_RETURN_IF_ERROR(op->Inputs()[i]->Device(&input_device));
if (op->Device() != input_device) {
return tensorflow::errors::InvalidArgument(
"Ops and inputs are not on the same device. Use "
"TFE_TensorHandleCopyToDevice to get ops on the same "
"device. Expected device: ",
op->Device()->name(), ", Actual device: ", input_device->name());
// TODO(b/110044833): It's possible the same tensor gets copied to the
// remote device repeatedly.
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
op, i, op->Device(), /* run_metadata= */ nullptr,
&(*op->MutableInputs())[i]));
}

tensorflow::TensorHandle* input = op->Inputs()[i];

tensorflow::uint64 op_id;
int32 output_num;
TF_RETURN_IF_ERROR(input->RemoteAddress(&op_id, &output_num));

@ -42,7 +42,7 @@ load(
# Check that tensorflow/core:tensorflow does not depend on grpc.
check_deps(
name = "core_tensorflow_check_deps",
disallowed_deps = ["@grpc//:grpc++_unsecure"],
disallowed_deps = ["@grpc//:grpc++"],
deps = ["//tensorflow/core:tensorflow"],
)

@ -150,7 +150,7 @@ tf_cuda_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
alwayslink = 1,
)
@ -170,7 +170,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
alwayslink = 1,
)

@ -649,7 +649,7 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:dense_update_ops",
"//tensorflow/core/kernels:identity_op",
"//tensorflow/core/kernels:variable_ops",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -682,7 +682,7 @@ tf_cuda_cc_test(
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -65,8 +65,8 @@ cc_library(
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc_unsecure",
"@grpc",
"@grpc//:grpc++",
],
)

@ -41,8 +41,8 @@ cc_library(
srcs = ["grpc_util.cc"],
hdrs = ["grpc_util.h"],
deps = [
"@grpc//:grpc_unsecure",
"@grpc//:grpc++_unsecure",
"@grpc",
"@grpc//:grpc++",
"//tensorflow/core:lib",
# Required to be able to overload TensorResponse parsing.
"//tensorflow/core/distributed_runtime:tensor_coding",
@ -56,7 +56,7 @@ cc_library(
deps = [
":grpc_util",
"//tensorflow/core:lib",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -70,7 +70,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core/distributed_runtime:call_options",
"//tensorflow/core/distributed_runtime:tensor_coding",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -90,7 +90,7 @@ cc_library(
"//tensorflow/core/distributed_runtime:tensor_coding",
"//tensorflow/core/distributed_runtime:worker_cache_logger",
"//tensorflow/core/distributed_runtime:worker_interface",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -103,7 +103,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -118,7 +118,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -129,7 +129,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -180,7 +180,7 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_session",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -192,7 +192,7 @@ cc_library(
":grpc_util",
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:tensor_coding",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -225,7 +225,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core/distributed_runtime:master",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
alwayslink = 1,
)
@ -236,7 +236,7 @@ cc_library(
hdrs = ["grpc_master_service_impl.h"],
deps = [
"//tensorflow/core:master_proto_cc",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -285,8 +285,8 @@ cc_library(
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:worker_env",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc_unsecure",
"@grpc",
"@grpc//:grpc++",
],
alwayslink = 1,
)
@ -313,7 +313,7 @@ tf_cc_binary(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/kernels:data_flow",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -338,7 +338,7 @@ tf_cc_binary(
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:reduction_ops",
"//tensorflow/core/kernels:variable_ops",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -432,7 +432,7 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core:worker_proto_cc",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -445,8 +445,8 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:worker_proto_cc",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc_unsecure",
"@grpc",
"@grpc//:grpc++",
],
)

@ -12,7 +12,7 @@ cc_library(
hdrs = ["grpc_eager_service.h"],
deps = [
"//tensorflow/core:eager_service_proto_cc",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -29,7 +29,7 @@ cc_library(
"//tensorflow/core/distributed_runtime/rpc:grpc_state",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

@ -48,7 +48,7 @@ cc_library(
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"@grpc//:grpc++_unsecure",
"@grpc//:grpc++",
],
)

29
tensorflow/core/kernels/cwise_op_bessel.cc
Normal file
@ -0,0 +1,29 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/cwise_ops_common.h"

namespace tensorflow {
REGISTER3(UnaryOp, CPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
double);
REGISTER3(UnaryOp, CPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
double);
#if GOOGLE_CUDA
REGISTER3(UnaryOp, GPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
double);
REGISTER3(UnaryOp, GPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
double);
#endif
} // namespace tensorflow

27
tensorflow/core/kernels/cwise_op_bessel.cu.cc
Normal file
@ -0,0 +1,27 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"

namespace tensorflow {
namespace functor {
DEFINE_UNARY3(bessel_i0e, Eigen::half, float, double);
DEFINE_UNARY3(bessel_i1e, Eigen::half, float, double);
} // namespace functor
} // namespace tensorflow

#endif // GOOGLE_CUDA

@ -616,6 +616,12 @@ struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};
template <typename T>
struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};

template <typename T>
struct bessel_i0e : base<T, Eigen::internal::scalar_i0e_op<T>> {};

template <typename T>
struct bessel_i1e : base<T, Eigen::internal::scalar_i1e_op<T>> {};

struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
};

@ -207,12 +207,6 @@ class IteratorResource : public ResourceBase {
return Status::OK();
}

std::shared_ptr<StatsAggregator> stats_aggregator() {
tf_shared_lock l(mu_);
return stats_aggregator_;
}

string DebugString() override { return "Iterator resource"; }

const DataTypeVector& output_dtypes() const { return output_dtypes_; }
@ -231,7 +225,6 @@ class IteratorResource : public ResourceBase {
FunctionLibraryRuntime* lib_ = nullptr; // not owned.
std::shared_ptr<IteratorBase> iterator_;
mutex mu_;
std::shared_ptr<StatsAggregator> stats_aggregator_ GUARDED_BY(mu_);
std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
const DataTypeVector output_dtypes_;
const std::vector<PartialTensorShape> output_shapes_;
@ -944,9 +937,6 @@ class IteratorGetNextOp : public AsyncOpKernel {

IteratorContext::Params params;
params.env = ctx->env();
params.stats_aggregator_getter = [iterator]() {
return iterator->stats_aggregator();
};
params.runner = *(ctx->runner());
params.function_library = iterator->function_library();
DeviceBase* device = ctx->function_library()->device();
@ -995,9 +985,6 @@ class IteratorGetNextSyncOp : public OpKernel {

IteratorContext::Params params;
params.env = ctx->env();
params.stats_aggregator_getter = [iterator]() {
return iterator->stats_aggregator();
};
params.runner = *(ctx->runner());
params.function_library = iterator->function_library();
DeviceBase* device = ctx->function_library()->device();

@ -10085,6 +10085,52 @@ op {
}
}
}
op {
name: "BesselI0e"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
}
op {
name: "BesselI1e"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
}
op {
name: "Betainc"
input_arg {
@ -25468,6 +25514,44 @@ op {
type: "func"
}
}
op {
name: "If"
input_arg {
name: "cond"
type_attr: "Tcond"
}
input_arg {
name: "input"
type_list_attr: "Tin"
}
output_arg {
name: "output"
type_list_attr: "Tout"
}
attr {
name: "Tcond"
type: "type"
}
attr {
name: "Tin"
type: "list(type)"
has_minimum: true
}
attr {
name: "Tout"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "then_branch"
type: "func"
}
attr {
name: "else_branch"
type: "func"
}
}
op {
name: "Igamma"
input_arg {

@ -239,6 +239,10 @@ REGISTER_OP("Acos").UNARY();

REGISTER_OP("Atan").UNARY();

REGISTER_OP("BesselI0e").UNARY_REAL();

REGISTER_OP("BesselI1e").UNARY_REAL();

#undef UNARY
#undef UNARY_REAL
#undef UNARY_COMPLEX

@ -3860,6 +3860,52 @@ op {
}
}
}
op {
name: "BesselI0e"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
}
op {
name: "BesselI1e"
input_arg {
name: "x"
type_attr: "T"
}
output_arg {
name: "y"
type_attr: "T"
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_BFLOAT16
type: DT_HALF
type: DT_FLOAT
type: DT_DOUBLE
}
}
}
}
op {
name: "Betainc"
input_arg {
@ -12358,7 +12404,6 @@ op {
name: "Tin"
type: "list(type)"
has_minimum: true
minimum: 1
}
attr {
name: "Tout"

@ -4210,69 +4210,6 @@ func Digamma(scope *Scope, x tf.Output) (y tf.Output) {
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Shuffle dimensions of x according to a permutation.
|
||||
//
|
||||
// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
|
||||
// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
|
||||
func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "Transpose",
|
||||
Input: []tf.Input{
|
||||
x, perm,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// MinAttr is an optional argument to Min.
|
||||
type MinAttr func(optionalAttr)
|
||||
|
||||
// MinKeepDims sets the optional keep_dims attribute to value.
|
||||
//
|
||||
// value: If true, retain reduced dimensions with length 1.
|
||||
// If not specified, defaults to false
|
||||
func MinKeepDims(value bool) MinAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["keep_dims"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Computes the minimum of elements across dimensions of a tensor.
|
||||
//
|
||||
// Reduces `input` along the dimensions given in `axis`. Unless
|
||||
// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
|
||||
// `axis`. If `keep_dims` is true, the reduced dimensions are
|
||||
// retained with length 1.
|
||||
//
|
||||
// Arguments:
|
||||
// input: The tensor to reduce.
|
||||
// axis: The dimensions to reduce. Must be in the range
|
||||
// `[-rank(input), rank(input))`.
|
||||
//
|
||||
// Returns The reduced tensor.
|
||||
func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "Min",
|
||||
Input: []tf.Input{
|
||||
input, axis,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter.
|
||||
type Conv2DBackpropFilterAttr func(optionalAttr)
|
||||
|
||||
@ -6181,6 +6118,77 @@ func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Computes offsets of concat inputs within its output.
|
||||
//
|
||||
// For example:
|
||||
//
|
||||
// ```
|
||||
// # 'x' is [2, 2, 7]
|
||||
// # 'y' is [2, 3, 7]
|
||||
// # 'z' is [2, 5, 7]
|
||||
// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0]
|
||||
// ```
|
||||
//
|
||||
// This is typically used by gradient computations for a concat operation.
|
||||
//
|
||||
// Arguments:
|
||||
// concat_dim: The dimension along which to concatenate.
|
||||
// shape: The `N` int32 vectors representing shape of tensors being concatenated.
|
||||
//
|
||||
// Returns The `N` int32 vectors representing the starting offset
|
||||
// of input tensors within the concatenated output.
|
||||
func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "ConcatOffset",
|
||||
Input: []tf.Input{
|
||||
concat_dim, tf.OutputList(shape),
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
var idx int
|
||||
var err error
|
||||
if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil {
|
||||
scope.UpdateErr("ConcatOffset", err)
|
||||
return
|
||||
}
|
||||
return offset
|
||||
}
|
||||
|
||||
// Compute the lower regularized incomplete Gamma function `Q(a, x)`.
|
||||
//
|
||||
// The lower regularized incomplete Gamma function is defined as:
|
||||
//
|
||||
//
|
||||
// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\)
|
||||
//
|
||||
// where
|
||||
//
|
||||
// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
|
||||
//
|
||||
// is the lower incomplete Gamma function.
|
||||
//
|
||||
// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete
|
||||
// Gamma function.
|
||||
func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "Igamma",
|
||||
Input: []tf.Input{
|
||||
a, x,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// DepthToSpaceAttr is an optional argument to DepthToSpace.
|
||||
type DepthToSpaceAttr func(optionalAttr)
|
||||
|
||||
@ -7000,6 +7008,69 @@ func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Shuffle dimensions of x according to a permutation.
|
||||
//
|
||||
// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
|
||||
// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
|
||||
func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "Transpose",
|
||||
Input: []tf.Input{
|
||||
x, perm,
|
||||
},
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// MinAttr is an optional argument to Min.
|
||||
type MinAttr func(optionalAttr)
|
||||
|
||||
// MinKeepDims sets the optional keep_dims attribute to value.
|
||||
//
|
||||
// value: If true, retain reduced dimensions with length 1.
|
||||
// If not specified, defaults to false
|
||||
func MinKeepDims(value bool) MinAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["keep_dims"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Computes the minimum of elements across dimensions of a tensor.
|
||||
//
|
||||
// Reduces `input` along the dimensions given in `axis`. Unless
|
||||
// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
|
||||
// `axis`. If `keep_dims` is true, the reduced dimensions are
|
||||
// retained with length 1.
|
||||
//
|
||||
// Arguments:
|
||||
// input: The tensor to reduce.
|
||||
// axis: The dimensions to reduce. Must be in the range
|
||||
// `[-rank(input), rank(input))`.
|
||||
//
|
||||
// Returns The reduced tensor.
|
||||
func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "Min",
|
||||
Input: []tf.Input{
|
||||
input, axis,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// Transforms a Tensor into a serialized TensorProto proto.
|
||||
//
|
||||
// Arguments:
|
||||
@ -11592,60 +11663,6 @@ func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// ResizeAreaAttr is an optional argument to ResizeArea.
|
||||
type ResizeAreaAttr func(optionalAttr)
|
||||
|
||||
// ResizeAreaAlignCorners sets the optional align_corners attribute to value.
|
||||
//
|
||||
// value: If true, the centers of the 4 corner pixels of the input and output tensors are
|
||||
// aligned, preserving the values at the corner pixels. Defaults to false.
|
||||
// If not specified, defaults to false
|
||||
func ResizeAreaAlignCorners(value bool) ResizeAreaAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["align_corners"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Resize `images` to `size` using area interpolation.
|
||||
//
|
||||
// Input images can be of different types but output images are always float.
|
||||
//
|
||||
// The range of pixel values for the output image might be slightly different
|
||||
// from the range for the input image because of limited numerical precision.
|
||||
// To guarantee an output range, for example `[0.0, 1.0]`, apply
|
||||
// `tf.clip_by_value` to the output.
|
||||
//
|
||||
// Each output pixel is computed by first transforming the pixel's footprint into
|
||||
// the input tensor and then averaging the pixels that intersect the footprint. An
|
||||
// input pixel's contribution to the average is weighted by the fraction of its
|
||||
// area that intersects the footprint. This is the same as OpenCV's INTER_AREA.
|
||||
//
|
||||
// Arguments:
|
||||
// images: 4-D with shape `[batch, height, width, channels]`.
|
||||
// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
|
||||
// new size for the images.
|
||||
//
|
||||
// Returns 4-D with shape
|
||||
// `[batch, new_height, new_width, channels]`.
|
||||
func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "ResizeArea",
|
||||
Input: []tf.Input{
|
||||
images, size,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}

// 2D real-valued fast Fourier transform.
//
// Computes the 2-dimensional discrete Fourier transform of a real-valued signal
@ -13635,170 +13652,6 @@ func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values
	return op.Output(0), op.Output(1)
}

// ComplexAttr is an optional argument to Complex.
type ComplexAttr func(optionalAttr)

// ComplexTout sets the optional Tout attribute to value.
// If not specified, defaults to DT_COMPLEX64
func ComplexTout(value tf.DataType) ComplexAttr {
	return func(m optionalAttr) {
		m["Tout"] = value
	}
}

// Converts two real numbers to a complex number.
//
// Given a tensor `real` representing the real part of a complex number, and a
// tensor `imag` representing the imaginary part of a complex number, this
// operation returns complex numbers elementwise of the form \\(a + bj\\), where
// *a* represents the `real` part and *b* represents the `imag` part.
//
// The input tensors `real` and `imag` must have the same shape.
//
// For example:
//
// ```
// # tensor 'real' is [2.25, 3.25]
// # tensor `imag` is [4.75, 5.75]
// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
// ```
func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) {
	if scope.Err() != nil {
		return
	}
	attrs := map[string]interface{}{}
	for _, a := range optional {
		a(attrs)
	}
	opspec := tf.OpSpec{
		Type: "Complex",
		Input: []tf.Input{
			real, imag,
		},
		Attrs: attrs,
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// ImagAttr is an optional argument to Imag.
type ImagAttr func(optionalAttr)

// ImagTout sets the optional Tout attribute to value.
// If not specified, defaults to DT_FLOAT
func ImagTout(value tf.DataType) ImagAttr {
	return func(m optionalAttr) {
		m["Tout"] = value
	}
}

// Returns the imaginary part of a complex number.
//
// Given a tensor `input` of complex numbers, this operation returns a tensor of
// type `float` that is the imaginary part of each element in `input`. All
// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
// is the real part and *b* is the imaginary part returned by this operation.
//
// For example:
//
// ```
// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
// tf.imag(input) ==> [4.75, 5.75]
// ```
func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) {
	if scope.Err() != nil {
		return
	}
	attrs := map[string]interface{}{}
	for _, a := range optional {
		a(attrs)
	}
	opspec := tf.OpSpec{
		Type: "Imag",
		Input: []tf.Input{
			input,
		},
		Attrs: attrs,
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// Computes the maximum along segments of a tensor.
//
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
// segments.
//
// Computes a tensor such that
// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
// that `segment_ids[j] == i`.
//
// If the max is empty for a given segment ID `i`, `output[i] = 0`.
//
// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
// </div>
//
// Arguments:
//
// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
// has size `k`, the number of segments.
func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
	if scope.Err() != nil {
		return
	}
	opspec := tf.OpSpec{
		Type: "SegmentMax",
		Input: []tf.Input{
			data, segment_ids,
		},
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}
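Editor's illustration (not part of this commit): a small numeric sketch of the SegmentMax semantics documented above, using the TF 1.x Python op:

```python
import tensorflow as tf

# output[i] is the max over all j with segment_ids[j] == i;
# ids must be sorted, and an empty segment would yield 0.
data = tf.constant([5.0, 1.0, 7.0, 2.0, 3.0, 4.0])
segment_ids = tf.constant([0, 0, 0, 1, 2, 2])
result = tf.segment_max(data, segment_ids)  # => [7.0, 2.0, 4.0]
```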
// Computes hyperbolic tangent of `x` element-wise.
func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
	if scope.Err() != nil {
		return
	}
	opspec := tf.OpSpec{
		Type: "Tanh",
		Input: []tf.Input{
			x,
		},
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// Creates a dataset that skips `count` elements from the `input_dataset`.
//
// Arguments:
//
// count: A scalar representing the number of elements from the `input_dataset`
// that should be skipped. If count is -1, skips everything.
//
//
func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
	if scope.Err() != nil {
		return
	}
	attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
	opspec := tf.OpSpec{
		Type: "SkipDataset",
		Input: []tf.Input{
			input_dataset, count,
		},
		Attrs: attrs,
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

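Editor's illustration (not part of this commit): the SkipDataset op as surfaced by the high-level Python API:

```python
import tensorflow as tf

# Skipping 2 of 5 elements leaves [2, 3, 4]; count=-1 skips everything.
dataset = tf.data.Dataset.range(5).skip(2)
```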
// Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
//
// The Hurwitz zeta function is defined as:
@ -14064,49 +13917,6 @@ func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
	return op.Output(0)
}

// RealAttr is an optional argument to Real.
type RealAttr func(optionalAttr)

// RealTout sets the optional Tout attribute to value.
// If not specified, defaults to DT_FLOAT
func RealTout(value tf.DataType) RealAttr {
	return func(m optionalAttr) {
		m["Tout"] = value
	}
}

// Returns the real part of a complex number.
//
// Given a tensor `input` of complex numbers, this operation returns a tensor of
// type `float` that is the real part of each element in `input`. All elements in
// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real
// part returned by this operation and *b* is the imaginary part.
//
// For example:
//
// ```
// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
// tf.real(input) ==> [-2.25, 3.25]
// ```
func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) {
	if scope.Err() != nil {
		return
	}
	attrs := map[string]interface{}{}
	for _, a := range optional {
		a(attrs)
	}
	opspec := tf.OpSpec{
		Type: "Real",
		Input: []tf.Input{
			input,
		},
		Attrs: attrs,
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

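Editor's illustration (not part of this commit): the Complex/Imag/Real ops above in one Python round trip, using the TF 1.x names:

```python
import tensorflow as tf

real = tf.constant([2.25, 3.25])
imag = tf.constant([4.75, 5.75])
z = tf.complex(real, imag)  # [2.25+4.75j, 3.25+5.75j]
r = tf.real(z)              # [2.25, 3.25]
i = tf.imag(z)              # [4.75, 5.75]
```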
// AudioSummaryAttr is an optional argument to AudioSummary.
type AudioSummaryAttr func(optionalAttr)

@ -19704,6 +19514,267 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or
	return op.Output(0)
}

// ComplexAttr is an optional argument to Complex.
type ComplexAttr func(optionalAttr)

// ComplexTout sets the optional Tout attribute to value.
// If not specified, defaults to DT_COMPLEX64
func ComplexTout(value tf.DataType) ComplexAttr {
	return func(m optionalAttr) {
		m["Tout"] = value
	}
}

// Converts two real numbers to a complex number.
//
// Given a tensor `real` representing the real part of a complex number, and a
// tensor `imag` representing the imaginary part of a complex number, this
// operation returns complex numbers elementwise of the form \\(a + bj\\), where
// *a* represents the `real` part and *b* represents the `imag` part.
//
// The input tensors `real` and `imag` must have the same shape.
//
// For example:
//
// ```
// # tensor 'real' is [2.25, 3.25]
// # tensor `imag` is [4.75, 5.75]
// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
// ```
func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) {
	if scope.Err() != nil {
		return
	}
	attrs := map[string]interface{}{}
	for _, a := range optional {
		a(attrs)
	}
	opspec := tf.OpSpec{
		Type: "Complex",
		Input: []tf.Input{
			real, imag,
		},
		Attrs: attrs,
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// ImagAttr is an optional argument to Imag.
type ImagAttr func(optionalAttr)

// ImagTout sets the optional Tout attribute to value.
// If not specified, defaults to DT_FLOAT
func ImagTout(value tf.DataType) ImagAttr {
	return func(m optionalAttr) {
		m["Tout"] = value
	}
}

// Returns the imaginary part of a complex number.
//
// Given a tensor `input` of complex numbers, this operation returns a tensor of
// type `float` that is the imaginary part of each element in `input`. All
// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
// is the real part and *b* is the imaginary part returned by this operation.
//
// For example:
//
// ```
// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
// tf.imag(input) ==> [4.75, 5.75]
// ```
func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) {
	if scope.Err() != nil {
		return
	}
	attrs := map[string]interface{}{}
	for _, a := range optional {
		a(attrs)
	}
	opspec := tf.OpSpec{
		Type: "Imag",
		Input: []tf.Input{
			input,
		},
		Attrs: attrs,
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// Computes the maximum along segments of a tensor.
//
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
// segments.
//
// Computes a tensor such that
// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
// that `segment_ids[j] == i`.
//
// If the max is empty for a given segment ID `i`, `output[i] = 0`.
//
// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
// </div>
//
// Arguments:
//
// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
// first dimension. Values should be sorted and can be repeated.
//
// Returns Has same shape as data, except for dimension 0 which
// has size `k`, the number of segments.
func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
	if scope.Err() != nil {
		return
	}
	opspec := tf.OpSpec{
		Type: "SegmentMax",
		Input: []tf.Input{
			data, segment_ids,
		},
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// Computes hyperbolic tangent of `x` element-wise.
func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
	if scope.Err() != nil {
		return
	}
	opspec := tf.OpSpec{
		Type: "Tanh",
		Input: []tf.Input{
			x,
		},
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// Creates a dataset that skips `count` elements from the `input_dataset`.
//
// Arguments:
//
// count: A scalar representing the number of elements from the `input_dataset`
// that should be skipped. If count is -1, skips everything.
//
//
func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
	if scope.Err() != nil {
		return
	}
	attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
	opspec := tf.OpSpec{
		Type: "SkipDataset",
		Input: []tf.Input{
			input_dataset, count,
		},
		Attrs: attrs,
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// RealAttr is an optional argument to Real.
type RealAttr func(optionalAttr)

// RealTout sets the optional Tout attribute to value.
// If not specified, defaults to DT_FLOAT
func RealTout(value tf.DataType) RealAttr {
	return func(m optionalAttr) {
		m["Tout"] = value
	}
}

// Returns the real part of a complex number.
//
// Given a tensor `input` of complex numbers, this operation returns a tensor of
// type `float` that is the real part of each element in `input`. All elements in
// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real
// part returned by this operation and *b* is the imaginary part.
//
// For example:
//
// ```
// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
// tf.real(input) ==> [-2.25, 3.25]
// ```
func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) {
	if scope.Err() != nil {
		return
	}
	attrs := map[string]interface{}{}
	for _, a := range optional {
		a(attrs)
	}
	opspec := tf.OpSpec{
		Type: "Real",
		Input: []tf.Input{
			input,
		},
		Attrs: attrs,
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// ResizeAreaAttr is an optional argument to ResizeArea.
type ResizeAreaAttr func(optionalAttr)

// ResizeAreaAlignCorners sets the optional align_corners attribute to value.
//
// value: If true, the centers of the 4 corner pixels of the input and output tensors are
// aligned, preserving the values at the corner pixels. Defaults to false.
// If not specified, defaults to false
func ResizeAreaAlignCorners(value bool) ResizeAreaAttr {
	return func(m optionalAttr) {
		m["align_corners"] = value
	}
}

// Resize `images` to `size` using area interpolation.
//
// Input images can be of different types but output images are always float.
//
// The range of pixel values for the output image might be slightly different
// from the range for the input image because of limited numerical precision.
// To guarantee an output range, for example `[0.0, 1.0]`, apply
// `tf.clip_by_value` to the output.
//
// Each output pixel is computed by first transforming the pixel's footprint into
// the input tensor and then averaging the pixels that intersect the footprint. An
// input pixel's contribution to the average is weighted by the fraction of its
// area that intersects the footprint. This is the same as OpenCV's INTER_AREA.
//
// Arguments:
// images: 4-D with shape `[batch, height, width, channels]`.
// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
// new size for the images.
//
// Returns 4-D with shape
// `[batch, new_height, new_width, channels]`.
func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) {
	if scope.Err() != nil {
		return
	}
	attrs := map[string]interface{}{}
	for _, a := range optional {
		a(attrs)
	}
	opspec := tf.OpSpec{
		Type: "ResizeArea",
		Input: []tf.Input{
			images, size,
		},
		Attrs: attrs,
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// VarHandleOpAttr is an optional argument to VarHandleOp.
type VarHandleOpAttr func(optionalAttr)

@ -30639,74 +30710,3 @@ func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Ou
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

// Compute the lower regularized incomplete Gamma function `P(a, x)`.
//
// The lower regularized incomplete Gamma function is defined as:
//
//
// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\)
//
// where
//
// \\(gamma(a, x) = \int_{0}^{x} t^{a-1} exp(-t) dt\\)
//
// is the lower incomplete Gamma function.
//
// Note, above `Q(a, x)` (`Igammac`) is the upper regularized incomplete
// Gamma function.
func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
	if scope.Err() != nil {
		return
	}
	opspec := tf.OpSpec{
		Type: "Igamma",
		Input: []tf.Input{
			a, x,
		},
	}
	op := scope.AddOperation(opspec)
	return op.Output(0)
}

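Editor's illustration (not part of this commit): the identities above checked numerically with SciPy, where `gammainc` is P(a, x) (this op) and `gammaincc` is Q(a, x) (`Igammac`):

```python
import numpy as np
from scipy import special

a, x = 2.0, 3.0
p = special.gammainc(a, x)   # lower regularized incomplete gamma, P(a, x)
q = special.gammaincc(a, x)  # upper regularized incomplete gamma, Q(a, x)
assert np.isclose(p + q, 1.0)
```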
// Computes offsets of concat inputs within its output.
//
// For example:
//
// ```
// # 'x' is [2, 2, 7]
// # 'y' is [2, 3, 7]
// # 'z' is [2, 5, 7]
// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0]
// ```
//
// This is typically used by gradient computations for a concat operation.
//
// Arguments:
// concat_dim: The dimension along which to concatenate.
// shape: The `N` int32 vectors representing shape of tensors being concatenated.
//
// Returns The `N` int32 vectors representing the starting offset
// of input tensors within the concatenated output.
func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) {
	if scope.Err() != nil {
		return
	}
	opspec := tf.OpSpec{
		Type: "ConcatOffset",
		Input: []tf.Input{
			concat_dim, tf.OutputList(shape),
		},
	}
	op := scope.AddOperation(opspec)
	if scope.Err() != nil {
		return
	}
	var idx int
	var err error
	if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil {
		scope.UpdateErr("ConcatOffset", err)
		return
	}
	return offset
}
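Editor's illustration (not part of this commit): the offsets from the doc example computed directly. Note the shapes shown differ in dimension 1, so the offsets correspond to concatenation along axis 1:

```python
import numpy as np

# Each input's offset is the cumulative extent of the preceding
# inputs along the concat axis; all other axes stay 0.
shapes = np.array([[2, 2, 7], [2, 3, 7], [2, 5, 7]])
axis = 1
offsets = np.zeros_like(shapes)
offsets[1:, axis] = np.cumsum(shapes[:-1, axis])
print(offsets)  # [[0 0 0] [0 2 0] [0 5 0]]
```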

@ -2530,6 +2530,7 @@ py_library(
        ":check_ops",
        ":confusion_matrix",
        ":control_flow_ops",
        ":distribute",
        ":framework",
        ":framework_for_generated_wrappers",
        ":math_ops",

@ -259,9 +259,7 @@ class DatasetConstructorTest(test.TestCase):
      sess.run(init_op)
      self.assertAllEqual([1, 2, 3], sess.run(get_next))
      self.assertAllEqual([4, 5, 6], sess.run(get_next))
      # NOTE(mrry): Type name in message differs between Python 2 (`long`) and
      # 3 (`int`).
      with self.assertRaisesOpError(r"invalid literal for"):
      with self.assertRaisesOpError("The expected type was int64"):
        sess.run(get_next)
      self.assertAllEqual([7, 8, 9], sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
@ -290,6 +288,34 @@ class DatasetConstructorTest(test.TestCase):
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testFromGeneratorStructureError(self):
    def generator():
      yield 1, 2
      yield 3, 4
      yield 5
      yield 6, 7, 8
      yield 9, 10

    iterator = (dataset_ops.Dataset.from_generator(
        generator, output_types=(dtypes.int64, dtypes.int64))
                .make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.test_session() as sess:
      sess.run(init_op)
      self.assertEqual((1, 2), sess.run(get_next))
      self.assertEqual((3, 4), sess.run(get_next))
      with self.assertRaisesOpError(
          r"The expected structure was \(tf\.int64, tf\.int64\)"):
        sess.run(get_next)
      with self.assertRaisesOpError(
          r"The expected structure was \(tf\.int64, tf\.int64\)"):
        sess.run(get_next)
      self.assertEqual((9, 10), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testFromGeneratorHeterogeneous(self):
    def generator():
      yield 1

@ -223,6 +223,13 @@ class Dataset(object):
  def from_tensors(tensors):
    """Creates a `Dataset` with a single element, comprising the given tensors.

    Note that if `tensors` contains a NumPy array, and eager execution is not
    enabled, the values will be embedded in the graph as one or more
    @{tf.constant} operations. For large datasets (> 1 GB), this can waste
    memory and run into byte limits of graph serialization. If tensors contains
    one or more large NumPy arrays, consider the alternative described in
    @{$programmers_guide/datasets#consuming_numpy_arrays$this guide}.

    Args:
      tensors: A nested structure of tensors.

@ -235,6 +242,13 @@ class Dataset(object):
  def from_tensor_slices(tensors):
    """Creates a `Dataset` whose elements are slices of the given tensors.

    Note that if `tensors` contains a NumPy array, and eager execution is not
    enabled, the values will be embedded in the graph as one or more
    @{tf.constant} operations. For large datasets (> 1 GB), this can waste
    memory and run into byte limits of graph serialization. If tensors contains
    one or more large NumPy arrays, consider the alternative described in
    @{$programmers_guide/datasets#consuming_numpy_arrays$this guide}.

    Args:
      tensors: A nested structure of tensors, each having the same size in the
        0th dimension.
@ -409,13 +423,23 @@ class Dataset(object):
        # Use the same _convert function from the py_func() implementation to
        # convert the returned values to arrays early, so that we can inspect
        # their values.
        # pylint: disable=protected-access
        ret_arrays = [
            script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype)
            for ret, dtype in zip(
                nest.flatten_up_to(output_types, values), flattened_types)
        ]
        # pylint: enable=protected-access
        try:
          flattened_values = nest.flatten_up_to(output_types, values)
        except (TypeError, ValueError):
          raise TypeError(
              "`generator` yielded an element that did not match the expected "
              "structure. The expected structure was %s, but the yielded "
              "element was %s." % (output_types, values))
        ret_arrays = []
        for ret, dtype in zip(flattened_values, flattened_types):
          try:
            ret_arrays.append(script_ops.FuncRegistry._convert(  # pylint: disable=protected-access
                ret, dtype=dtype.as_numpy_dtype))
          except (TypeError, ValueError):
            raise TypeError(
                "`generator` yielded an element that could not be converted to "
                "the expected type. The expected type was %s, but the yielded "
                "element was %s." % (dtype.name, ret))

        # Additional type and shape checking to ensure that the components
        # of the generated element match the `output_types` and `output_shapes`

@ -451,42 +451,48 @@ def get_error_intro(tf_error):
    sample commands for debugging.
  """

  op_name = tf_error.op.name
  if hasattr(tf_error, "op") and hasattr(tf_error.op, "name"):
    op_name = tf_error.op.name
  else:
    op_name = None

  intro_lines = [
      "--------------------------------------",
      RL("!!! An error occurred during the run !!!", "blink"),
      "",
      "You may use the following commands to debug:",
  ]

  out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines)

  out.extend(
      _recommend_command("ni -a -d -t %s" % op_name,
                         "Inspect information about the failing op.",
                         create_link=True))
  out.extend(
      _recommend_command("li -r %s" % op_name,
                         "List inputs to the failing op, recursively.",
                         create_link=True))
  if op_name is not None:
    out.extend(debugger_cli_common.RichTextLines(
        ["You may use the following commands to debug:"]))
    out.extend(
        _recommend_command("ni -a -d -t %s" % op_name,
                           "Inspect information about the failing op.",
                           create_link=True))
    out.extend(
        _recommend_command("li -r %s" % op_name,
                           "List inputs to the failing op, recursively.",
                           create_link=True))

  out.extend(
      _recommend_command(
          "lt",
          "List all tensors dumped during the failing run() call.",
          create_link=True))
    out.extend(
        _recommend_command(
            "lt",
            "List all tensors dumped during the failing run() call.",
            create_link=True))
  else:
    out.extend(debugger_cli_common.RichTextLines([
        "WARNING: Cannot determine the name of the op that caused the error."]))

  more_lines = [
      "",
      "Op name: " + op_name,
      "Op name: %s" % op_name,
      "Error type: " + str(type(tf_error)),
      "",
      "Details:",
      str(tf_error),
      "",
      "WARNING: Using client GraphDef due to the error, instead of "
      "executor GraphDefs.",
      "--------------------------------------",
      "",
  ]

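Editor's illustration (not part of this commit): the new fallback path above, exercised the same way as the unit test added below; an `OpError` constructed with no op attached now produces the warning instead of raising:

```python
from tensorflow.python.debug.cli import cli_shared
from tensorflow.python.framework import errors

tf_error = errors.OpError(None, None, "Fake OpError", -1)
intro = cli_shared.get_error_intro(tf_error)  # contains the warning line
```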
@ -372,6 +372,11 @@ class GetErrorIntroTest(test_util.TensorFlowTestCase):
    self.assertEqual("Details:", error_intro.lines[14])
    self.assertStartsWith(error_intro.lines[15], "foo description")

  def testGetErrorIntroForNoOpName(self):
    tf_error = errors.OpError(None, None, "Fake OpError", -1)
    error_intro = cli_shared.get_error_intro(tf_error)
    self.assertIn("Cannot determine the name of the op", error_intro.lines[3])


if __name__ == "__main__":
  googletest.main()

@ -69,6 +69,12 @@ run
exit
EOF

cat << EOF | ${DEBUG_ERRORS_BIN} --error=uninitialized_variable --debug --ui_type=readline
run
ni -a -d -t v/read
exit
EOF

cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline
run -t 1
run --node_name_filter hidden --op_type_filter MatMul

@ -748,7 +748,7 @@ class DebugDumpDir(object):
    return sum(len(self._dump_tensor_data[device_name])
               for device_name in self._dump_tensor_data)

  def _load_partition_graphs(self, partition_graphs, validate):
  def _load_partition_graphs(self, client_partition_graphs, validate):
    """Load and process partition graphs.

    Load the graphs; parse the input and control input structure; obtain the
@ -757,8 +757,10 @@ class DebugDumpDir(object):
    tensor dumps.

    Args:
      partition_graphs: A repeated field of GraphDefs representing the
        partition graphs executed by the TensorFlow runtime.
      client_partition_graphs: A repeated field of GraphDefs representing the
        partition graphs executed by the TensorFlow runtime, from the Python
        client. These partition graphs are used only if partition graphs
        cannot be loaded from the dump directory on the file system.
      validate: (`bool`) Whether the dump files are to be validated against the
        partition graphs.

@ -769,24 +771,23 @@ class DebugDumpDir(object):
    self._debug_graphs = {}
    self._node_devices = {}

    if partition_graphs:
      partition_graphs_and_device_names = [
          (partition_graph, None) for partition_graph in partition_graphs]
    else:
      partition_graphs_and_device_names = []
      for device_name in self._device_names:
        partition_graph = None
        if device_name in self._dump_graph_file_paths:
          partition_graph = _load_graph_def_from_event_file(
              self._dump_graph_file_paths[device_name])
        else:
          partition_graph = self._find_partition_graph(partition_graphs,
                                                       device_name)
        if partition_graph:
          partition_graphs_and_device_names.append((partition_graph,
                                                    device_name))
        else:
          logging.warn("Failed to load partition graphs from disk.")
    partition_graphs_and_device_names = []
    for device_name in self._device_names:
      partition_graph = None
      if device_name in self._dump_graph_file_paths:
        partition_graph = _load_graph_def_from_event_file(
            self._dump_graph_file_paths[device_name])
      else:
        logging.warn(
            "Failed to load partition graphs for device %s from disk. "
            "As a fallback, the client graphs will be used. This "
            "may cause mismatches in device names." % device_name)
        partition_graph = self._find_partition_graph(client_partition_graphs,
                                                     device_name)

      if partition_graph:
        partition_graphs_and_device_names.append((partition_graph,
                                                  device_name))

    for partition_graph, maybe_device_name in partition_graphs_and_device_names:
      debug_graph = debug_graphs.DebugGraph(partition_graph,

@ -1873,6 +1873,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
    delete backward_function;
  });

  Py_DECREF(num_inputs);

  Py_RETURN_NONE;
}

@ -1931,8 +1933,10 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
    Py_INCREF(output->get());  // stay alive after since tuple steals.
    PyTuple_SET_ITEM(outputs.get(), 0, output->get());

    if (!RecordGradient(GetPythonObjectFromString("ReadVariableOp"),
                        inputs.get(), Py_None, outputs.get(), Py_None)) {
    tensorflow::Safe_PyObjectPtr op_string(
        GetPythonObjectFromString("ReadVariableOp"));
    if (!RecordGradient(op_string.get(), inputs.get(), Py_None, outputs.get(),
                        Py_None)) {
      return false;
    }
  }

@ -1242,11 +1242,11 @@ class TensorFlowTestCase(googletest.TestCase):
            b,
            rtol=rtol,
            atol=atol,
            msg="Mismatched value: a%s is different from b%s." % (path_str,
                                                                  path_str))
            msg=("Mismatched value: a%s is different from b%s. %s" %
                 (path_str, path_str, msg)))
      except TypeError as e:
        msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a),
                                                     path_str, type(b))
        msg = ("Error: a%s has %s, but b%s has %s. %s" %
               (path_str, type(a), path_str, type(b), msg))
        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
        raise


@ -118,6 +118,7 @@ class LocallyConnectedLayersTest(test.TestCase):
        },
        input_shape=(num_samples, num_row, num_col, stack_size))

  @tf_test_util.run_in_graph_and_eager_modes()
  def test_locallyconnected_2d_channels_first(self):
    num_samples = 8
    filters = 3
@ -125,15 +126,14 @@ class LocallyConnectedLayersTest(test.TestCase):
    num_row = 6
    num_col = 10

    with self.test_session():
      testing_utils.layer_test(
          keras.layers.LocallyConnected2D,
          kwargs={
              'filters': filters,
              'kernel_size': 3,
              'data_format': 'channels_first'
          },
          input_shape=(num_samples, num_row, num_col, stack_size))
    testing_utils.layer_test(
        keras.layers.LocallyConnected2D,
        kwargs={
            'filters': filters,
            'kernel_size': 3,
            'data_format': 'channels_first'
        },
        input_shape=(num_samples, num_row, num_col, stack_size))

  def test_locallyconnected_2d_regularization(self):
    num_samples = 8

@ -241,6 +241,12 @@ class UnaryOpTest(test.TestCase):
                      math_ops.lgamma)
    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
    except ImportError as e:
      tf_logging.warn("Cannot test special functions: %s" % str(e))

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
@ -286,6 +292,12 @@ class UnaryOpTest(test.TestCase):
    self._compareBoth(x, np.arcsin, math_ops.asin)
    self._compareBoth(x, np.arccos, math_ops.acos)
    self._compareBoth(x, np.arctan, math_ops.atan)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
    except ImportError as e:
      tf_logging.warn("Cannot test special functions: %s" % str(e))

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
@ -334,6 +346,12 @@ class UnaryOpTest(test.TestCase):
    self._compareBoth(k, np.arcsin, math_ops.asin)
    self._compareBoth(k, np.arccos, math_ops.acos)
    self._compareBoth(k, np.tan, math_ops.tan)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
    except ImportError as e:
      tf_logging.warn("Cannot test special functions: %s" % str(e))

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)
@ -370,6 +388,12 @@ class UnaryOpTest(test.TestCase):
                      math_ops.lgamma)
    self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
    self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
      self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
      self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
    except ImportError as e:
      tf_logging.warn("Cannot test special functions: %s" % str(e))

    self._compareBothSparse(x, np.abs, math_ops.abs)
    self._compareBothSparse(x, np.negative, math_ops.negative)

@ -939,7 +939,8 @@ class ResizeMethod(object):
def resize_images(images,
                  size,
                  method=ResizeMethod.BILINEAR,
                  align_corners=False):
                  align_corners=False,
                  preserve_aspect_ratio=False):
  """Resize `images` to `size` using the specified `method`.

  Resized images will be distorted if their original aspect ratio is not
@ -971,6 +972,10 @@ def resize_images(images,
    align_corners: bool. If True, the centers of the 4 corner pixels of the
      input and output tensors are aligned, preserving the values at the
      corner pixels. Defaults to `False`.
    preserve_aspect_ratio: Whether to preserve the aspect ratio. If this is set,
      then `images` will be resized to a size that fits in `size` while
      preserving the aspect ratio of the original image. Scales up the image if
      `size` is bigger than the current size of the `image`. Defaults to False.

  Raises:
    ValueError: if the shape of `images` is incompatible with the
@ -1009,6 +1014,28 @@ def resize_images(images,
    new_height_const = size_const_as_shape[0].value
    new_width_const = size_const_as_shape[1].value

    if preserve_aspect_ratio:
      # Get the current shapes of the image, even if dynamic.
      _, current_height, current_width, _ = _ImageDimensions(images, rank=4)

      # do the computation to find the right scale and height/width.
      scale_factor_height = (math_ops.to_float(new_height_const) /
                             math_ops.to_float(current_height))
      scale_factor_width = (math_ops.to_float(new_width_const) /
                            math_ops.to_float(current_width))
      scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width)
      scaled_height_const = math_ops.to_int32(scale_factor *
                                              math_ops.to_float(current_height))
      scaled_width_const = math_ops.to_int32(scale_factor *
                                             math_ops.to_float(current_width))

      # NOTE: Reset the size and other constants used later.
      size = ops.convert_to_tensor([scaled_height_const, scaled_width_const],
                                   dtypes.int32, name='size')
      size_const_as_shape = tensor_util.constant_value_as_shape(size)
      new_height_const = size_const_as_shape[0].value
      new_width_const = size_const_as_shape[1].value

    # If we can determine that the height and width will be unmodified by this
    # transformation, we avoid performing the resize.
    if all(x is not None
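Editor's illustration (not part of this commit): the scale computation above worked through with concrete numbers, fitting a 400x600 image into size=[300, 300]:

```python
# The smaller of the two per-axis scale factors is applied to both axes.
current_height, current_width = 400, 600
new_height, new_width = 300, 300
scale = min(new_height / current_height, new_width / current_width)  # 0.5
scaled = (int(scale * current_height), int(scale * current_width))
print(scaled)  # (200, 300) -- fits inside 300x300, aspect ratio preserved
```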
@ -1469,6 +1496,75 @@ def adjust_hue(image, delta, name=None):
    return convert_image_dtype(rgb_altered, orig_dtype)


# pylint: disable=invalid-name
@tf_export('image.random_jpeg_quality')
def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None):
  """Randomly changes jpeg encoding quality for inducing jpeg noise.

  `min_jpeg_quality` must be in the interval `[0, 100]` and less than
  `max_jpeg_quality`.
  `max_jpeg_quality` must be in the interval `[0, 100]`.

  Args:
    image: RGB image or images. Size of the last dimension must be 3.
    min_jpeg_quality: Minimum jpeg encoding quality to use.
    max_jpeg_quality: Maximum jpeg encoding quality to use.
    seed: An operation-specific seed. It will be used in conjunction
      with the graph-level seed to determine the real seeds that will be
      used in this operation. Please see the documentation of
      set_random_seed for its interaction with the graph-level random seed.

  Returns:
    Adjusted image(s), same shape and DType as `image`.

  Raises:
    ValueError: if `min_jpeg_quality` or `max_jpeg_quality` is invalid.
  """
  if (min_jpeg_quality < 0 or max_jpeg_quality < 0 or
      min_jpeg_quality > 100 or max_jpeg_quality > 100):
    raise ValueError('jpeg encoding range must be between 0 and 100.')

  if min_jpeg_quality >= max_jpeg_quality:
    raise ValueError('`min_jpeg_quality` must be less than `max_jpeg_quality`.')

  np.random.seed(seed)
  jpeg_quality = np.random.randint(min_jpeg_quality, max_jpeg_quality)
  return adjust_jpeg_quality(image, jpeg_quality)


@tf_export('image.adjust_jpeg_quality')
def adjust_jpeg_quality(image, jpeg_quality, name=None):
  """Adjust jpeg encoding quality of an RGB image.

  This is a convenience method that adjusts jpeg encoding quality of an
  RGB image.

  `image` is an RGB image. The image's encoding quality is adjusted
  to `jpeg_quality`.
  `jpeg_quality` must be in the interval `[0, 100]`.

  Args:
    image: RGB image or images. Size of the last dimension must be 3.
    jpeg_quality: int. jpeg encoding quality.
    name: A name for this operation (optional).

  Returns:
    Adjusted image(s), same shape and DType as `image`.
  """
  with ops.name_scope(name, 'adjust_jpeg_quality', [image]) as name:
    image = ops.convert_to_tensor(image, name='image')
    # Remember the original dtype so we can convert back if needed
    orig_dtype = image.dtype
    # Convert to uint8
    image = convert_image_dtype(image, dtypes.uint8)
    # Encode image to jpeg with given jpeg quality
    image = gen_image_ops.encode_jpeg(image, quality=jpeg_quality)
    # Decode jpeg image
    image = gen_image_ops.decode_jpeg(image)
    # Convert back to original dtype and return
    return convert_image_dtype(image, orig_dtype)


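Editor's illustration (not part of this commit): usage of the two functions added above, assuming the `tf.image` export names they register:

```python
import tensorflow as tf

# jpeg noise at a random quality drawn from [30, 70), or a fixed quality.
image = tf.placeholder(tf.float32, shape=[None, None, 3])
noisy = tf.image.random_jpeg_quality(image, 30, 70)
fixed = tf.image.adjust_jpeg_quality(image, 50)
```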
@tf_export('image.random_saturation')
def random_saturation(image, lower, upper, seed=None):
  """Adjust the saturation of an RGB image by a random factor.

@ -2599,6 +2599,86 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
    y = image_ops.resize_images(single_image, [55, 66])
    self.assertTrue(y.op.name.startswith("resize_images"))

  def _ResizeImageCall(self, x, max_h, max_w, preserve_aspect_ratio,
                       use_tensor_inputs):
    if use_tensor_inputs:
      target_max = ops.convert_to_tensor([max_h, max_w])
      x_tensor = array_ops.placeholder(x.dtype, shape=[None] * x.ndim)
      feed_dict = {x_tensor: x}
    else:
      target_max = [max_h, max_w]
      x_tensor = x
      feed_dict = {}

    y = image_ops.resize_images(x_tensor, target_max,
                                preserve_aspect_ratio=preserve_aspect_ratio)

    with self.test_session(use_gpu=True):
      return y.eval(feed_dict=feed_dict)

  def _assertResizeEqual(self, x, x_shape, y, y_shape,
                         preserve_aspect_ratio=True,
                         use_tensor_inputs_options=None):
    use_tensor_inputs_options = use_tensor_inputs_options or [False, True]
    target_height, target_width, _ = y_shape
    x = np.array(x).reshape(x_shape)
    y = np.array(y).reshape(y_shape)

    for use_tensor_inputs in use_tensor_inputs_options:
      y_tf = self._ResizeImageCall(x, target_height, target_width,
                                   preserve_aspect_ratio, use_tensor_inputs)
      self.assertAllClose(y, y_tf)

  def _assertResizeCheckShape(self, x, x_shape, target_shape,
                              y_shape, preserve_aspect_ratio=True,
                              use_tensor_inputs_options=None):
    use_tensor_inputs_options = use_tensor_inputs_options or [False, True]
    target_height, target_width = target_shape
    x = np.array(x).reshape(x_shape)
    y = np.zeros(y_shape)

    for use_tensor_inputs in use_tensor_inputs_options:
      y_tf = self._ResizeImageCall(x, target_height, target_width,
                                   preserve_aspect_ratio, use_tensor_inputs)
      self.assertShapeEqual(y, ops.convert_to_tensor(y_tf))

  def testPreserveAspectRatioMultipleImages(self):
    x_shape = [10, 100, 100, 10]
    x = np.random.uniform(size=x_shape)

    self._assertResizeCheckShape(x, x_shape, [250, 250], [10, 250, 250, 10],
                                 preserve_aspect_ratio=False)

  def testPreserveAspectRatioNoOp(self):
    x_shape = [10, 10, 10]
    x = np.random.uniform(size=x_shape)

    self._assertResizeEqual(x, x_shape, x, x_shape)

  def testPreserveAspectRatioSmaller(self):
    x_shape = [100, 100, 10]
    x = np.random.uniform(size=x_shape)

    self._assertResizeCheckShape(x, x_shape, [75, 50], [50, 50, 10])

  def testPreserveAspectRatioSmallerMultipleImages(self):
    x_shape = [10, 100, 100, 10]
    x = np.random.uniform(size=x_shape)

    self._assertResizeCheckShape(x, x_shape, [75, 50], [10, 50, 50, 10])

  def testPreserveAspectRatioLarger(self):
    x_shape = [100, 100, 10]
    x = np.random.uniform(size=x_shape)

    self._assertResizeCheckShape(x, x_shape, [150, 200], [150, 150, 10])

  def testPreserveAspectRatioSameRatio(self):
    x_shape = [1920, 1080, 3]
    x = np.random.uniform(size=x_shape)

    self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3])


class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase):


@ -620,6 +620,35 @@ def _DigammaGrad(op, grad):
  return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)


@ops.RegisterGradient("BesselI0e")
def _BesselI0eGrad(op, grad):
  """Compute gradient of bessel_i0e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    return grad * (math_ops.bessel_i1e(x) - math_ops.sign(x) * y)


@ops.RegisterGradient("BesselI1e")
def _BesselI1eGrad(op, grad):
  """Compute gradient of bessel_i1e(x) with respect to its argument."""
  x = op.inputs[0]
  y = op.outputs[0]
  with ops.control_dependencies([grad]):
    # For x = 0, the correct gradient is 0.5.
    # However, the main branch gives NaN because of the division by x, so
    # we impute the gradient manually.
    # An alternative solution is to express the gradient via bessel_i0e and
    # bessel_i2e, but the latter is not yet implemented in Eigen.
    eps = np.finfo(x.dtype.as_numpy_dtype).eps
    zeros = array_ops.zeros_like(x)
    x_is_not_tiny = math_ops.abs(x) > eps
    safe_x = array_ops.where(x_is_not_tiny, x, eps + zeros)
    dy_dx = math_ops.bessel_i0e(safe_x) - y * (
        math_ops.sign(safe_x) + math_ops.reciprocal(safe_x))
    return grad * array_ops.where(x_is_not_tiny, dy_dx, 0.5 + zeros)


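Editor's illustration (not part of this commit): the i1e gradient formula implemented above, checked against a finite difference with SciPy:

```python
import numpy as np
from scipy import special

# d/dx i1e(x) = i0e(x) - i1e(x) * (sign(x) + 1/x), for x away from 0.
x, h = 1.5, 1e-6
analytic = special.i0e(x) - special.i1e(x) * (np.sign(x) + 1.0 / x)
numeric = (special.i1e(x + h) - special.i1e(x - h)) / (2.0 * h)
assert np.isclose(analytic, numeric, atol=1e-5)
```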
@ops.RegisterGradient("Igamma")
|
||||
def _IgammaGrad(op, grad):
|
||||
"""Returns gradient of igamma(a, x) with respect to x."""
|
||||
|
@ -2954,6 +2954,67 @@ def polyval(coeffs, x, name=None):
|
||||
p = c + p * x
|
||||
return p
|
||||
|
||||
|
||||
@tf_export("math.bessel_i0e")
|
||||
def bessel_i0e(x, name=None):
|
||||
"""Computes the Bessel i0e function of `x` element-wise.
|
||||
|
||||
Exponentially scaled modified Bessel function of order 0 defined as
|
||||
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
|
||||
|
||||
This function is faster and numerically stabler than `bessel_i0(x)`.
|
||||
|
||||
Args:
|
||||
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
|
||||
`float32`, `float64`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
|
||||
|
||||
@compatibility(scipy)
|
||||
Equivalent to scipy.special.i0e
|
||||
@end_compatibility
|
||||
"""
|
||||
with ops.name_scope(name, "bessel_i0e", [x]) as name:
|
||||
if isinstance(x, sparse_tensor.SparseTensor):
|
||||
x_i0e = gen_math_ops.bessel_i0e(x.values, name=name)
|
||||
return sparse_tensor.SparseTensor(
|
||||
indices=x.indices, values=x_i0e, dense_shape=x.dense_shape)
|
||||
else:
|
||||
return gen_math_ops.bessel_i0e(x, name=name)
|
||||
|
||||
|
||||
@tf_export("math.bessel_i1e")
|
||||
def bessel_i1e(x, name=None):
|
||||
"""Computes the Bessel i1e function of `x` element-wise.
|
||||
|
||||
Exponentially scaled modified Bessel function of order 1 defined as
|
||||
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
|
||||
|
||||
This function is faster and numerically stabler than `bessel_i1(x)`.
|
||||
|
||||
Args:
|
||||
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
|
||||
`float32`, `float64`.
|
||||
name: A name for the operation (optional).
|
||||
|
||||
Returns:
|
||||
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
|
||||
|
||||
@compatibility(scipy)
|
||||
Equivalent to scipy.special.i1e
|
||||
@end_compatibility
|
||||
"""
|
||||
with ops.name_scope(name, "bessel_i1e", [x]) as name:
|
||||
if isinstance(x, sparse_tensor.SparseTensor):
|
||||
x_i1e = gen_math_ops.bessel_i1e(x.values, name=name)
|
||||
return sparse_tensor.SparseTensor(
|
||||
indices=x.indices, values=x_i1e, dense_shape=x.dense_shape)
|
||||
else:
|
||||
return gen_math_ops.bessel_i1e(x, name=name)
|
||||
|
||||
|
||||
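Editor's illustration (not part of this commit): the @compatibility notes above, made concrete; SciPy provides reference values for the two new exports:

```python
import numpy as np
from scipy import special

x = np.array([-1.0, 0.0, 2.5])
print(special.i0e(x))  # reference values for tf.math.bessel_i0e
print(special.i1e(x))  # reference values for tf.math.bessel_i1e
```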
# FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow
# 1.0 API so we leave these here for backwards compatibility.
fft = gen_spectral_ops.fft

@ -34,21 +34,54 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


def metric_variable(shape, dtype, validate_shape=True, name=None):
  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections."""
  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.

  return variable_scope.variable(
      lambda: array_ops.zeros(shape, dtype),
      trainable=False,
      collections=[
          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
      ],
      validate_shape=validate_shape,
      name=name)
  If running in a `DistributionStrategy` context, the variable will be
  "tower local". This means:

  * The returned object will be a container with separate variables
    per replica/tower of the model.

  * When writing to the variable, e.g. using `assign_add` in a metric
    update, the update will be applied to the variable local to the
    replica/tower.

  * To get a metric's result value, we need to sum the variable values
    across the replicas/towers before computing the final answer.
    Furthermore, the final answer should be computed once instead of
    in every replica/tower. Both of these are accomplished by
    running the computation of the final result value inside
    `tf.contrib.distribute.get_tower_context().merge_call(fn)`.
    Inside the `merge_call()`, ops are only added to the graph once
    and access to a tower-local variable in a computation returns
    the sum across all replicas/towers.

  Args:
    shape: Shape of the created variable.
    dtype: Type of the created variable.
    validate_shape: (Optional) Whether shape validation is enabled for
      the created variable.
    name: (Optional) String name of the created variable.

  Returns:
    A (non-trainable) variable initialized to zero, or if inside a
    `DistributionStrategy` scope a tower-local variable container.
  """
  with distribute_lib.get_tower_context().tower_local_var_scope('sum'):
    # Note that "tower local" implies trainable=False.
    return variable_scope.variable(
        lambda: array_ops.zeros(shape, dtype),
        collections=[
            ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
        ],
        validate_shape=validate_shape,
        name=name)


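Editor's illustration (not part of this commit): how metric_variable-backed metrics are typically driven in a single-tower TF 1.x session; the value op reads the accumulators while the update op advances them:

```python
import tensorflow as tf

values = tf.constant([1.0, 2.0, 3.0, 4.0])
mean_value, update_op = tf.metrics.mean(values)
with tf.Session() as sess:
  sess.run(tf.local_variables_initializer())  # metric variables are local
  sess.run(update_op)
  print(sess.run(mean_value))  # 2.5
```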
def _remove_squeezable_dimensions(predictions, labels, weights):
|
||||
@ -333,11 +366,15 @@ def mean(values,
|
||||
with ops.control_dependencies([values]):
|
||||
update_count_op = state_ops.assign_add(count, num_values)
|
||||
|
||||
mean_t = _safe_div(total, count, 'value')
|
||||
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
|
||||
def aggregate_across_towers(_, t, c):
|
||||
mean_t = _safe_div(t, c, 'value')
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_t)
|
||||
return mean_t
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_t)
|
||||
mean_t = distribute_lib.get_tower_context().merge_call(
|
||||
aggregate_across_towers, total, count)
|
||||
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
@ -572,6 +609,17 @@ def _confusion_matrix_at_thresholds(labels,
|
||||
return values, update_ops
|
||||
|
||||
|
||||
def _aggregate_variable(v, collections):
|
||||
|
||||
def f(distribution, value):
|
||||
value = distribution.read_var(value)
|
||||
if collections:
|
||||
ops.add_to_collections(collections, value)
|
||||
return value
|
||||
|
||||
return distribute_lib.get_tower_context().merge_call(f, v)
|
||||
|
||||
|
||||
@tf_export('metrics.auc')
|
||||
def auc(labels,
|
||||
predictions,
|
||||
@ -757,14 +805,18 @@ def auc(labels,
|
||||
raise ValueError('Invalid summation_method: %s' % summation_method)
|
||||
|
||||
# sum up the areas of all the trapeziums
|
||||
auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
|
||||
values['fp'], 'value')
|
||||
def aggregate_auc(_, values):
|
||||
auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
|
||||
values['fp'], 'value')
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, auc_value)
|
||||
return auc_value
|
||||
|
||||
auc_value = distribute_lib.get_tower_context().merge_call(
|
||||
aggregate_auc, values)
|
||||
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
|
||||
update_ops['tn'], update_ops['fp'], 'update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, auc_value)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
@ -992,15 +1044,18 @@ def mean_per_class_accuracy(labels,
|
||||
update_total_op = state_ops.scatter_add(total, labels, ones)
|
||||
update_count_op = state_ops.scatter_add(count, labels, is_correct)
|
||||
|
||||
per_class_accuracy = _safe_div(count, total, None)
|
||||
def aggregate_mean_accuracy(_, count, total):
|
||||
per_class_accuracy = _safe_div(count, total, None)
|
||||
mean_accuracy_v = math_ops.reduce_mean(
|
||||
per_class_accuracy, name='mean_accuracy')
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_accuracy_v)
|
||||
return mean_accuracy_v
|
||||
|
||||
mean_accuracy_v = distribute_lib.get_tower_context().merge_call(
|
||||
aggregate_mean_accuracy, count, total)
|
||||
|
||||
mean_accuracy_v = math_ops.reduce_mean(
|
||||
per_class_accuracy, name='mean_accuracy')
|
||||
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_accuracy_v)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
@ -1071,7 +1126,7 @@ def mean_iou(labels,
|
||||
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
|
||||
num_classes, weights)
|
||||
|
||||
def compute_mean_iou(name):
|
||||
def compute_mean_iou(total_cm, name):
|
||||
"""Compute the mean intersection-over-union via the confusion matrix."""
|
||||
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
|
||||
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
|
||||
@ -1098,10 +1153,14 @@ def mean_iou(labels,
|
||||
math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
|
||||
return result
|
||||
|
||||
mean_iou_v = compute_mean_iou('mean_iou')
|
||||
def mean_iou_across_towers(_, v):
|
||||
mean_iou_v = compute_mean_iou(v, 'mean_iou')
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_iou_v)
|
||||
return mean_iou_v
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_iou_v)
|
||||
mean_iou_v = distribute_lib.get_tower_context().merge_call(
|
||||
mean_iou_across_towers, total_cm)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
@ -1310,12 +1369,16 @@ def mean_tensor(values,
|
||||
with ops.control_dependencies([values]):
|
||||
update_count_op = state_ops.assign_add(count, num_values)
|
||||
|
||||
mean_t = _safe_div(total, count, 'value')
|
||||
def aggregate_across_towers(_, t, c):
|
||||
mean_t = _safe_div(t, c, 'value')
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_t)
|
||||
return mean_t
|
||||
|
||||
mean_t = distribute_lib.get_tower_context().merge_call(
|
||||
aggregate_across_towers, total, count)
|
||||
|
||||
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, mean_t)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
@ -1413,12 +1476,9 @@ def _count_condition(values,
|
||||
weights = math_ops.to_float(weights)
|
||||
values = math_ops.multiply(values, weights)
|
||||
|
||||
value_tensor = array_ops.identity(count)
|
||||
value_tensor = _aggregate_variable(count, metrics_collections)
|
||||
|
||||
update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, value_tensor)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_op)
|
||||
|
||||
@ -1525,13 +1585,12 @@ def false_negatives_at_thresholds(labels,
|
||||
values, update_ops = _confusion_matrix_at_thresholds(
|
||||
labels, predictions, thresholds, weights=weights, includes=('fn',))
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, values['fn'])
|
||||
fn_value = _aggregate_variable(values['fn'], metrics_collections)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_ops['fn'])
|
||||
|
||||
return values['fn'], update_ops['fn']
|
||||
return fn_value, update_ops['fn']
|
||||
|
||||
|
||||
@tf_export('metrics.false_positives')
|
||||
@ -1635,13 +1694,12 @@ def false_positives_at_thresholds(labels,
|
||||
values, update_ops = _confusion_matrix_at_thresholds(
|
||||
labels, predictions, thresholds, weights=weights, includes=('fp',))
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, values['fp'])
|
||||
fp_value = _aggregate_variable(values['fp'], metrics_collections)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_ops['fp'])
|
||||
|
||||
return values['fp'], update_ops['fp']
|
||||
return fp_value, update_ops['fp']
|
||||
|
||||
|
||||
@tf_export('metrics.true_negatives')
|
||||
@ -1745,13 +1803,12 @@ def true_negatives_at_thresholds(labels,
|
||||
values, update_ops = _confusion_matrix_at_thresholds(
|
||||
labels, predictions, thresholds, weights=weights, includes=('tn',))
|
||||
|
||||
if metrics_collections:
|
||||
ops.add_to_collections(metrics_collections, values['tn'])
|
||||
tn_value = _aggregate_variable(values['tn'], metrics_collections)
|
||||
|
||||
if updates_collections:
|
||||
ops.add_to_collections(updates_collections, update_ops['tn'])
|
||||
|
||||
return values['tn'], update_ops['tn']
|
||||
return tn_value, update_ops['tn']
|
||||
|
||||
|
||||
@tf_export('metrics.true_positives')
|
||||
@ -1855,13 +1912,12 @@ def true_positives_at_thresholds(labels,
    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights=weights, includes=('tp',))

    if metrics_collections:
      ops.add_to_collections(metrics_collections, values['tp'])
    tp_value = _aggregate_variable(values['tp'], metrics_collections)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_ops['tp'])

    return values['tp'], update_ops['tp']
    return tp_value, update_ops['tp']


@tf_export('metrics.precision')
@ -1945,13 +2001,17 @@ def precision(labels,
      return array_ops.where(
          math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)

    p = compute_precision(true_p, false_p, 'value')
    def once_across_towers(_, true_p, false_p):
      p = compute_precision(true_p, false_p, 'value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, p)
      return p

    p = distribute_lib.get_tower_context().merge_call(
        once_across_towers, true_p, false_p)

    update_op = compute_precision(true_positives_update_op,
                                  false_positives_update_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, p)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)
@ -2025,13 +2085,17 @@ def precision_at_thresholds(labels,
    def compute_precision(tp, fp, name):
      return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)

    prec = compute_precision(values['tp'], values['fp'], 'value')
    def precision_across_towers(_, values):
      prec = compute_precision(values['tp'], values['fp'], 'value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, prec)
      return prec

    prec = distribute_lib.get_tower_context().merge_call(
        precision_across_towers, values)

    update_op = compute_precision(update_ops['tp'], update_ops['fp'],
                                  'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, prec)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)
@ -2050,7 +2114,7 @@ def recall(labels,
  The `recall` function creates two local variables, `true_positives`
  and `false_negatives`, that are used to compute the recall. This value is
  ultimately returned as `recall`, an idempotent operation that simply divides
  `true_positives` by the sum of `true_positives` and `false_negatives`.
  `true_positives` by the sum of `true_positives` and `false_negatives`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` that updates these variables and returns the `recall`. `update_op`
@ -2117,13 +2181,17 @@
          math_ops.greater(true_p + false_n, 0),
          math_ops.div(true_p, true_p + false_n), 0, name)

    rec = compute_recall(true_p, false_n, 'value')
    def once_across_towers(_, true_p, false_n):
      rec = compute_recall(true_p, false_n, 'value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, rec)
      return rec

    rec = distribute_lib.get_tower_context().merge_call(
        once_across_towers, true_p, false_n)

    update_op = compute_recall(true_positives_update_op,
                               false_negatives_update_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, rec)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)
@ -2552,11 +2620,17 @@ def recall_at_top_k(labels,
        class_id=class_id,
        weights=weights)

    metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
    def aggregate_across_towers(_, tp, fn):
      metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
      if metrics_collections:
        ops.add_to_collections(metrics_collections, metric)
      return metric

    metric = distribute_lib.get_tower_context().merge_call(
        aggregate_across_towers, tp, fn)

    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fn_update), name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update
@ -2627,12 +2701,16 @@ def recall_at_thresholds(labels,
    def compute_recall(tp, fn, name):
      return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)

    rec = compute_recall(values['tp'], values['fn'], 'value')
    def recall_across_towers(_, values):
      rec = compute_recall(values['tp'], values['fn'], 'value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, rec)
      return rec

    rec = distribute_lib.get_tower_context().merge_call(
        recall_across_towers, values)

    update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, rec)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)
@ -2698,13 +2776,16 @@ def root_mean_squared_error(labels,
  mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
                                          None, name or
                                          'root_mean_squared_error')
  def once_across_towers(_, mse):
    rmse = math_ops.sqrt(mse)
    if metrics_collections:
      ops.add_to_collections(metrics_collections, rmse)
    return rmse

  rmse = distribute_lib.get_tower_context().merge_call(
      once_across_towers, mse)

  rmse = math_ops.sqrt(mse)
  update_rmse_op = math_ops.sqrt(update_mse_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, rmse)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_rmse_op)
@ -2797,15 +2878,19 @@ def sensitivity_at_specificity(labels,
      return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
                          name)

    sensitivity = compute_sensitivity_at_specificity(
        values['tp'], values['tn'], values['fp'], values['fn'], 'value')
    def aggregate_across_towers(_, values):
      sensitivity = compute_sensitivity_at_specificity(
          values['tp'], values['tn'], values['fp'], values['fn'], 'value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, sensitivity)
      return sensitivity

    sensitivity = distribute_lib.get_tower_context().merge_call(
        aggregate_across_towers, values)

    update_op = compute_sensitivity_at_specificity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, sensitivity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)
@ -3070,11 +3155,16 @@ def _streaming_sparse_average_precision_at_top_k(labels,
    total_update = state_ops.assign_add(total_var, batch_total, name='update')

    # Divide total by max to get mean, for both vars and the update ops.
    mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
    update = _safe_scalar_div(total_update, max_update, name=scope)
    def aggregate_across_towers(_, total_var, max_var):
      mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, mean_average_precision)
      return mean_average_precision

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_average_precision)
    mean_average_precision = distribute_lib.get_tower_context().merge_call(
        aggregate_across_towers, total_var, max_var)

    update = _safe_scalar_div(total_update, max_update, name=scope)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
@ -3351,11 +3441,17 @@ def precision_at_top_k(labels,
        class_id=class_id,
        weights=weights)

    metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
    def aggregate_across_towers(_, tp, fp):
      metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
      if metrics_collections:
        ops.add_to_collections(metrics_collections, metric)
      return metric

    metric = distribute_lib.get_tower_context().merge_call(
        aggregate_across_towers, tp, fp)

    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fp_update), name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update
@ -3583,15 +3679,19 @@ def specificity_at_sensitivity(labels,
      return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
                          name)

    specificity = compute_specificity_at_sensitivity(
        values['tp'], values['tn'], values['fp'], values['fn'], 'value')
    def aggregate_across_towers(_, values):
      specificity = compute_specificity_at_sensitivity(
          values['tp'], values['tn'], values['fp'], values['fn'], 'value')
      if metrics_collections:
        ops.add_to_collections(metrics_collections, specificity)
      return specificity

    specificity = distribute_lib.get_tower_context().merge_call(
        aggregate_across_towers, values)

    update_op = compute_specificity_at_sensitivity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, specificity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)
@ -82,6 +82,54 @@ def lbeta(x, name='lbeta'):
    return result


@tf_export('math.bessel_i0')
def bessel_i0(x, name='bessel_i0'):
  """Computes the Bessel i0 function of `x` element-wise.

  Modified Bessel function of order 0.

  It is preferable to use the numerically stabler function `i0e(x)` instead.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0
  @end_compatibility
  """
  with ops.name_scope(name, [x]):
    return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x)


@tf_export('math.bessel_i1')
def bessel_i1(x, name='bessel_i1'):
  """Computes the Bessel i1 function of `x` element-wise.

  Modified Bessel function of order 1.

  It is preferable to use the numerically stabler function `i1e(x)` instead.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1
  @end_compatibility
  """
  with ops.name_scope(name, [x]):
    return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x)
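
Both wrappers rely on the identity i0(x) = exp(|x|) * i0e(x) (and likewise for i1/i1e), where the "e" variants are the exponentially scaled modified Bessel functions exposed by `math_ops`; computing the scaled function and multiplying by exp(|x|) keeps the kernel numerically well behaved. The identity is easy to check against scipy directly, which is the same oracle the new test further below uses:

import numpy as np
from scipy import special

x = np.arange(-3.0, 3.0)
# i0e(x) == exp(-|x|) * i0(x), so multiplying by exp(|x|) recovers i0(x);
# the same holds for i1/i1e.
np.testing.assert_allclose(np.exp(np.abs(x)) * special.i0e(x), special.i0(x))
np.testing.assert_allclose(np.exp(np.abs(x)) * special.i1e(x), special.i1(x))
print('identity holds on', x)
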
@tf_export('einsum', 'linalg.einsum')
def einsum(equation, *inputs, **kwargs):
  """A generalized contraction between tensors of arbitrary dimension.
@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


class LBetaTest(test.TestCase):
@ -150,6 +151,33 @@ class LBetaTest(test.TestCase):
    self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())


class BesselTest(test.TestCase):

  def test_bessel_i0(self):
    x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
    x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
      self.assertAllClose(special.i0(x_single),
                          self.evaluate(special_math_ops.bessel_i0(x_single)))
      self.assertAllClose(special.i0(x_double),
                          self.evaluate(special_math_ops.bessel_i0(x_double)))
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s' % str(e))

  def test_bessel_i1(self):
    x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
    x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
      self.assertAllClose(special.i1(x_single),
                          self.evaluate(special_math_ops.bessel_i1(x_single)))
      self.assertAllClose(special.i1(x_double),
                          self.evaluate(special_math_ops.bessel_i1(x_double)))
    except ImportError as e:
      tf_logging.warn('Cannot test special functions: %s' % str(e))


class EinsumTest(test.TestCase):

  simple_cases = [
@ -79,12 +79,14 @@ def _parse_saved_model(export_dir):
                          constants.SAVED_MODEL_FILENAME_PB))


def _get_asset_tensors(export_dir, meta_graph_def_to_load):
def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Args:
    export_dir: Directory where the SavedModel is located.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
        '/' to all returned asset tensor names.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. The
@ -104,7 +106,10 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load):
    for asset_any_proto in assets_any_proto:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
      tensor_name = asset_proto.tensor_info.name
      if import_scope:
        tensor_name = "%s/%s" % (import_scope, tensor_name)
      asset_tensor_dict[tensor_name] = os.path.join(
          compat.as_bytes(assets_directory),
          compat.as_bytes(asset_proto.filename))
  return asset_tensor_dict
@ -179,7 +184,7 @@ def maybe_saved_model_directory(export_dir):


@tf_export("saved_model.loader.load")
def load(sess, tags, export_dir, **saver_kwargs):
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Args:
@ -189,6 +194,10 @@ def load(sess, tags, export_dir, **saver_kwargs):
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
@ -216,7 +225,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
      )

  # Build a saver by importing the meta graph def to load.
  saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
  saver = tf_saver.import_meta_graph(
      meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs)

  if saver:
    # Build the checkpoint path where the variables are located.
@ -232,8 +242,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
                       "checkpoints were restored.")

    # Get asset tensors, if any.
    asset_tensors_dictionary = _get_asset_tensors(export_dir,
                                                  meta_graph_def_to_load)
    asset_tensors_dictionary = _get_asset_tensors(
        export_dir, meta_graph_def_to_load, import_scope=import_scope)

    main_op_tensor = (
        _get_main_op_tensor(meta_graph_def_to_load) or
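
The new `import_scope` plumbing makes it possible to load the same SavedModel into one graph more than once without tensor-name collisions. A minimal usage sketch under stated assumptions: the export directory and scope names below are placeholders, and a model tagged with SERVING is assumed to already exist on disk.

import tensorflow as tf
from tensorflow.python.saved_model import loader

export_dir = "/tmp/my_saved_model"  # placeholder path

with tf.Session(graph=tf.Graph()) as sess:
  # Tensors from the first load are named "copy_a/...", from the second
  # "copy_b/...", so the two subgraphs coexist. Note that the returned
  # MetaGraphDef is *not* rewritten with the scope.
  loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir,
              import_scope="copy_a")
  loader.load(sess, [tf.saved_model.tag_constants.SERVING], export_dir,
              import_scope="copy_b")
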
@ -1197,6 +1197,59 @@ class SavedModelTest(test.TestCase):
    _validate_custom_saver("tag_1", "save_1/restore_all")
    _validate_custom_saver("tag_2", "save_2/restore_all")

  def testImportScope(self):
    export_dir = self._get_export_dir("test_scoped_assets")
    builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Build a SavedModel with a variable, an asset, and a constant tensor.
    with self.test_session(graph=ops.Graph()) as sess:
      self._init_and_validate_variable(sess, "v", 42)
      asset_collection = self._build_asset_collection("foo.txt", "content_foo",
                                                      "asset_file_tensor")
      constant_op.constant("constant value", name="constant_tensor_name")
      builder.add_meta_graph_and_variables(
          sess, ["tag_name"], assets_collection=asset_collection)

      # Save the asset file path for later comparison.
      asset_file_path = asset_collection[0].eval()

    # Save the SavedModel to disk.
    builder.save()

    with self.test_session(graph=ops.Graph()) as sess:
      # Restore the SavedModel under an import_scope in a new graph/session.
      graph_proto = loader.load(
          sess, ["tag_name"], export_dir, import_scope="scope_name")

      # The loaded variable tensor should be scoped, but its contents should be
      # unchanged.
      self.assertEqual(
          "scope_name/v:0",
          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name)
      self.assertEqual(
          42,
          ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())

      # The loaded asset tensor should be scoped, but the asset file path and
      # contents should be unchanged.
      asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
      self.assertEqual(1, len(asset_collection))
      self.assertEqual(asset_file_path, asset_collection[0].eval())
      self.assertEqual("scope_name/asset_file_tensor:0",
                       asset_collection[0].name)
      # The static asset data inside graph_proto.collection_def should not be
      # scoped.
      self._validate_asset_collection(export_dir, graph_proto.collection_def,
                                      "foo.txt", "content_foo",
                                      "asset_file_tensor:0")

      # The constant tensor should be scoped, but its contents should be
      # unchanged.
      self.assertEqual(
          compat.as_bytes("constant value"),
          ops.get_default_graph().get_tensor_by_name(
              "scope_name/constant_tensor_name:0").eval())

  def testClearDevices(self):
    export_dir = self._get_export_dir("test_clear_devices")
    builder = saved_model_builder.SavedModelBuilder(export_dir)
@ -528,6 +528,8 @@ class DistributionStrategy(object):
    * `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower
      context, like `d.update()` except with locality N.
    * `d.fetch(t)`: Copy `t` with any locality to the client's CPU device.
      TODO(josh11b): Deprecate `fetch`, switch to `read_var` for
      reading tower-local variables.

    The standard pattern for updating variables is to:

@ -614,8 +616,8 @@

    There will still be one component variable per tower, but there is
    no requirement that they stay in sync. Instead, when saving them
    or calling `fetch()`, we use the value that results when calling
    `reduce()` on all the towers' variables.
    or calling `fetch()/read_var()`, we use the value that
    results when calling `reduce()` on all the towers' variables.

    Note: tower-local implies not trainable. Instead, it is expected
    that each tower will directly update (using `assign_add()` or
@ -646,6 +648,21 @@ class DistributionStrategy(object):
    _require_distribution_strategy_scope(self)
    return variable_scope.variable_creator_scope(create_tower_local_variable)

  def read_var(self, v):
    """Reads the value of a variable.

    Returns the aggregate value of a tower-local variable, or the
    (read-only) value of any other variable.

    Args:
      v: A variable allocated within the scope of this `DistributionStrategy`.

    Returns:
      A tensor representing the value of `v`, aggregated across towers if
      necessary.
    """
    raise NotImplementedError("must be implemented in descendants")

  def colocate_vars_with(self, colocate_with_variable):
    """Scope that controls which devices variables will be created on.

@ -904,6 +921,8 @@ class DistributionStrategy(object):
    will attempt to avoid a copy by checking if the value is already
    on the destination device.

    TODO(josh11b): Switch to `read_var`.

    Args:
      val: Value (which may be mirrored) to copy.
      destination: A device string to copy the value to.
@ -1197,6 +1216,9 @@ class _DefaultDistributionStrategy(DistributionStrategy):
    with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
      return fn(*args, **kwargs)

  def read_var(self, tower_local_var):
    return array_ops.identity(tower_local_var)

  def _fetch(self, var, destination, fn):
    with ops.colocate_with(var):
      var = fn(var)
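
`read_var` gives callers a supported way to get the aggregated value of a tower-local variable, with `fetch` slated for deprecation per the TODOs above. Under the default single-tower strategy it degenerates to an identity read, which the following sketch exercises; that `get_distribution_strategy()` is the module-level helper returning the default strategy is an assumption about this module, not something shown in the hunks here.

import tensorflow as tf
from tensorflow.python.training import distribute as distribute_lib

v = tf.Variable(3.0)
# Assumed helper: returns the default strategy when none is active.
strategy = distribute_lib.get_distribution_strategy()
read = strategy.read_var(v)  # identity read when there is a single tower

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(read))  # 3.0
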
@ -1970,7 +1970,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,

    return Saver(saver_def=meta_graph_def.saver_def, name=scope)
  else:
    if variables._all_saveable_objects():  # pylint: disable=protected-access
    if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
      # Return the default saver instance for all graph variables.
      return Saver()
    else:
@ -2339,6 +2339,46 @@ class MetaGraphTest(test.TestCase):
            10, size=[1, 10])
    })

  def testImportIntoNamescopeWithoutVariables(self):
    # Save a simple graph that contains no variables into a checkpoint.
    test_dir = self._get_test_dir("no_vars_graph")
    filename = os.path.join(test_dir, "ckpt")
    graph_1 = ops_lib.Graph()
    with session.Session(graph=graph_1) as sess:
      constant_op.constant([1, 2, 3], name="x")
      constant_op.constant([1, 2, 3], name="y")
      saver = saver_module.Saver(allow_empty=True)
      saver.save(sess, filename)

    # Create a fresh graph.
    graph_2 = ops_lib.Graph()
    with session.Session(graph=graph_2) as sess:
      # Restore the above checkpoint under scope "subgraph_1".
      new_saver_1 = saver_module.import_meta_graph(
          filename + ".meta", graph=graph_2, import_scope="subgraph_1")
      # There are no variables to restore, so import_meta_graph should not
      # return a Saver.
      self.assertIsNone(new_saver_1)

      # Create a variable in graph_2 under scope "my_scope".
      variables.Variable(array_ops.zeros([10]), name="my_scope/my_var")
      sess.run(variables.global_variables_initializer())
      # Restore the checkpoint into a different scope "subgraph_2".
      new_saver_2 = saver_module.import_meta_graph(
          filename + ".meta", graph=graph_2, import_scope="subgraph_2")
      # Because the variable does not live in scope "subgraph_2",
      # import_meta_graph should not attempt to restore the variable. So,
      # import_meta_graph still won't return a Saver instance.
      self.assertIsNone(new_saver_2)

      # However, if we restore the checkpoint under scope "my_scope",
      # import_meta_graph will detect the variable and return a Saver for
      # restoring it. This should happen even when the variable does not
      # originate from graph_1.
      new_saver_3 = saver_module.import_meta_graph(
          filename + ".meta", graph=graph_2, import_scope="my_scope")
      self.assertIsInstance(new_saver_3, saver_module.Saver)

  def testImportIntoImplicitNamescope(self):
    # Test that we can import a meta graph into an implicit namescope.
    test_dir = self._get_test_dir("import_into_namescope")
@ -24,17 +24,12 @@ limitations under the License.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef __APPLE__
#include <IOKit/kext/KextManager.h>
#include <mach-o/dyld.h>
#else
#if !defined(PLATFORM_WINDOWS)
#include <link.h>
#include <sys/sysmacros.h>
#include <unistd.h>
#endif
#include <sys/stat.h>
#endif
#include <algorithm>
#include <memory>
#include <vector>
@ -54,9 +49,7 @@ limitations under the License.
namespace stream_executor {
namespace cuda {

#ifdef __APPLE__
static const CFStringRef kDriverKextIdentifier = CFSTR("com.nvidia.CUDA");
#elif !defined(PLATFORM_WINDOWS)
#if !defined(PLATFORM_WINDOWS)
static const char *kDriverVersionPath = "/proc/driver/nvidia/version";
#endif

@ -121,26 +114,7 @@ string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
}

void Diagnostician::LogDiagnosticInformation() {
#ifdef __APPLE__
  CFStringRef kext_ids[1];
  kext_ids[0] = kDriverKextIdentifier;
  CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void**)kext_ids, 1, &kCFTypeArrayCallBacks);
  CFDictionaryRef kext_infos = KextManagerCopyLoadedKextInfo(kext_id_query, nullptr);
  CFRelease(kext_id_query);

  CFDictionaryRef cuda_driver_info = nullptr;
  if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, (const void**)&cuda_driver_info)) {
    bool started = CFBooleanGetValue((CFBooleanRef)CFDictionaryGetValue(cuda_driver_info, CFSTR("OSBundleStarted")));
    if (!started) {
      LOG(INFO) << "kernel driver is installed, but does not appear to be running on this host "
                << "(" << port::Hostname() << ")";
    }
  } else {
    LOG(INFO) << "kernel driver does not appear to be installed on this host "
              << "(" << port::Hostname() << ")";
  }
  CFRelease(kext_infos);
#elif !defined(PLATFORM_WINDOWS)
#if !defined(PLATFORM_WINDOWS)
  if (access(kDriverVersionPath, F_OK) != 0) {
    LOG(INFO) << "kernel driver does not appear to be running on this host "
              << "(" << port::Hostname() << "): "
@ -194,8 +168,7 @@ void Diagnostician::LogDiagnosticInformation() {
            << DriverVersionStatusToString(kernel_version);
#endif

  // OS X kernel driver does not report version accurately
#if !defined(__APPLE__) && !defined(PLATFORM_WINDOWS)
#if !defined(PLATFORM_WINDOWS)
  if (kernel_version.ok() && dso_version.ok()) {
    WarnOnDsoKernelMismatch(dso_version, kernel_version);
  }
@ -209,29 +182,6 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
      port::error::NOT_FOUND,
      "was unable to find libcuda.so DSO loaded into this program"));

#if defined(__APPLE__)
  // OSX CUDA libraries have names like: libcuda_310.41.15_mercury.dylib
  const string prefix("libcuda_");
  const string suffix("_mercury.dylib");
  for (uint32_t image_index = 0; image_index < _dyld_image_count(); ++image_index) {
    const string path(_dyld_get_image_name(image_index));
    const size_t suffix_pos = path.rfind(suffix);
    const size_t prefix_pos = path.rfind(prefix, suffix_pos);
    if (prefix_pos == string::npos ||
        suffix_pos == string::npos) {
      // no match
      continue;
    }
    const size_t start = prefix_pos + prefix.size();
    if (start >= suffix_pos) {
      // version not included
      continue;
    }
    const size_t length = suffix_pos - start;
    const string version = path.substr(start, length);
    result = StringToDriverVersion(version);
  }
#else
#if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA)
// Callback used when iterating through DSOs. Looks for the driver-interfacing
// DSO and yields its version number into the callback data, when found.
@ -264,7 +214,6 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
  };

  dl_iterate_phdr(iterate_phdr, &result);
#endif
#endif

  return result;
@ -310,38 +259,7 @@ void Diagnostician::WarnOnDsoKernelMismatch(


port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
#if defined(__APPLE__)
  CFStringRef kext_ids[1];
  kext_ids[0] = kDriverKextIdentifier;
  CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void**)kext_ids, 1, &kCFTypeArrayCallBacks);
  CFDictionaryRef kext_infos = KextManagerCopyLoadedKextInfo(kext_id_query, nullptr);
  CFRelease(kext_id_query);

  CFDictionaryRef cuda_driver_info = nullptr;
  if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, (const void**)&cuda_driver_info)) {
    // NOTE: OSX CUDA driver does not currently store the same driver version
    // in kCFBundleVersionKey as is returned by cuDriverGetVersion
    CFRelease(kext_infos);
    const CFStringRef str = (CFStringRef)CFDictionaryGetValue(
        cuda_driver_info, kCFBundleVersionKey);
    const char *version = CFStringGetCStringPtr(str, kCFStringEncodingUTF8);

    // version can be NULL in which case treat it as empty string
    // see
    // https://developer.apple.com/library/mac/documentation/CoreFoundation/Conceptual/CFStrings/Articles/AccessingContents.html#//apple_ref/doc/uid/20001184-100980-TPXREF112
    if (version == NULL) {
      return StringToDriverVersion("");
    }
    return StringToDriverVersion(version);
  }
  CFRelease(kext_infos);
  auto status = port::Status(
      port::error::INTERNAL,
      port::StrCat(
          "failed to read driver bundle version: ",
          CFStringGetCStringPtr(kDriverKextIdentifier, kCFStringEncodingUTF8)));
  return status;
#elif defined(PLATFORM_WINDOWS)
#if defined(PLATFORM_WINDOWS)
  auto status =
      port::Status(port::error::UNIMPLEMENTED,
                   "kernel reported driver version not implemented on Windows");
@ -495,10 +495,10 @@ PersistentRnnPlan CreatePersistentRnnPlan(cudnnRNNDescriptor_t rnn_desc,

// Turns a BatchDescriptor structure into a cudnn tensor handle within a
// scope.
class ScopedTensorDescriptor {
class CudnnTensorDescriptor {
 public:
  ScopedTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
                         cudnnDataType_t elem_type)
  CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
                        cudnnDataType_t elem_type)
      : handle_(CreateTensorDescriptor()) {
    switch (batch_descriptor.layout()) {
      case dnn::DataLayout::kBatchYXDepth:
@ -540,15 +540,15 @@ class ScopedTensorDescriptor {
 private:
  TensorDescriptor handle_;

  SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
};

// Turns a FilterDescriptor structure into a cudnn filter handle within a
// scope.
class ScopedFilterDescriptor {
class CudnnFilterDescriptor {
 public:
  ScopedFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
                         cudnnDataType_t elem_type)
  CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
                        cudnnDataType_t elem_type)
      : handle_(CreateFilterDescriptor()) {
    // TODO(b/23032134): Even if the filter layout is not supported,
    // cudnnSetFilter4DDescriptor_v4 will return CUDNN_STATUS_SUCCESS because
@ -586,7 +586,7 @@ class ScopedFilterDescriptor {
 private:
  FilterDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
};

// A helper function to decide whether to enable the TENSOR_OP_MATH math type
@ -636,9 +636,9 @@ bool BatchnormSpatialPersistentEnabled() {

// Turns a ConvolutionDescriptor structure into a cudnn convolution handle
// within a scope.
class ScopedConvolutionDescriptor {
class CudnnConvolutionDescriptor {
 public:
  ScopedConvolutionDescriptor(
  CudnnConvolutionDescriptor(
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      cudnnDataType_t data_type)
      : handle_(CreateConvolutionDescriptor()) {
@ -700,14 +700,14 @@ class ScopedConvolutionDescriptor {
 private:
  ConvolutionDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
};

// Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
// within a scope.
class ScopedPoolingDescriptor {
class CudnnPoolingDescriptor {
 public:
  explicit ScopedPoolingDescriptor(
  explicit CudnnPoolingDescriptor(
      const dnn::PoolingDescriptor& pooling_descriptor)
      : handle_(CreatePoolingDescriptor()) {
    const std::vector<int64> strides64 = pooling_descriptor.strides();
@ -739,13 +739,13 @@ class ScopedPoolingDescriptor {
 private:
  PoolingDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
};

// Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
class ScopedNormalizeDescriptor {
class CudnnNormalizeDescriptor {
 public:
  explicit ScopedNormalizeDescriptor(
  explicit CudnnNormalizeDescriptor(
      const dnn::NormalizeDescriptor& normalize_descriptor)
      : handle_(CreateLrnDescriptor()) {
    // The range specifies that the indices in the closed range
@ -777,16 +777,16 @@ class ScopedNormalizeDescriptor {
 private:
  LrnDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
};

// Turns a ActivationDescriptor structure into a cudnn activation
// descriptor handle within a scope.
class ScopedActivationDescriptor {
class CudnnActivationDescriptor {
 public:
  ScopedActivationDescriptor(dnn::ActivationMode activation_mode,
                             cudnnNanPropagation_t nan_propagation,
                             double value_max)
  CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
                            cudnnNanPropagation_t nan_propagation,
                            double value_max)
      : handle_(CreateActivationDescriptor()) {
    double relu_ceiling = 0.0;
    cudnnActivationMode_t mode;
@ -822,7 +822,7 @@ class ScopedActivationDescriptor {
 private:
  ActivationDescriptor handle_;  // Owned.

  SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
};

cudnnDataType_t ToCudnnDataType(
@ -888,21 +888,21 @@ int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
  }
}

class ScopedDropoutDescriptor {
  explicit ScopedDropoutDescriptor(DropoutDescriptor handle)
class CudnnDropoutDescriptor {
  explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
      : handle_(std::move(handle)) {}

 public:
  ScopedDropoutDescriptor(ScopedDropoutDescriptor&&) = default;
  CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;

  static port::StatusOr<ScopedDropoutDescriptor> Create(
  static port::StatusOr<CudnnDropoutDescriptor> Create(
      const CudnnHandle& cudnn, float dropout, uint64 seed,
      ScratchAllocator* state_allocator) {
    DropoutDescriptor handle = CreateDropoutDescriptor();

    if (dropout == 0.0f) {
      // Return 'empty' dropout descriptor.
      return ScopedDropoutDescriptor(std::move(handle));
      return CudnnDropoutDescriptor(std::move(handle));
    }

    DeviceMemory<uint8> state_memory;
@ -917,14 +917,14 @@ class ScopedDropoutDescriptor {
        handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
        state_memory.size(), seed));

    return ScopedDropoutDescriptor(std::move(handle));
    return CudnnDropoutDescriptor(std::move(handle));
  }

  cudnnDropoutDescriptor_t handle() const { return handle_.get(); }

 private:
  DropoutDescriptor handle_;  // Owned.
  SE_DISALLOW_COPY_AND_ASSIGN(ScopedDropoutDescriptor);
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
};

class CudnnRnnParamsDescriptor {
@ -973,7 +973,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
                     cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
                     cudnnDataType_t compute_type,
                     const dnn::AlgorithmConfig& algorithm_config,
                     ScopedDropoutDescriptor dropout_desc,
                     CudnnDropoutDescriptor dropout_desc,
                     CudnnRnnParamsDescriptor params_desc)
      : rnn_desc_(std::move(rnn_desc)),
        rnn_plan_(std::move(rnn_plan)),
@ -1002,8 +1002,8 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
      const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
      ScratchAllocator* state_allocator) {
    SE_ASSIGN_OR_RETURN(
        ScopedDropoutDescriptor dropout_desc,
        ScopedDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
        CudnnDropoutDescriptor dropout_desc,
        CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));

    cuda::RnnDescriptor rnn_desc = CreateRnnDescriptor();
    cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
@ -1097,7 +1097,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
  cudnnDataType_t data_type_;
  cudnnDataType_t compute_type_;
  dnn::AlgorithmConfig algorithm_config_;
  ScopedDropoutDescriptor dropout_desc_;
  CudnnDropoutDescriptor dropout_desc_;
  CudnnRnnParamsDescriptor params_desc_;
  SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
};
@ -1926,10 +1926,9 @@ namespace {
// and backward filter.

port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
    const CudnnHandle& cudnn, const ScopedTensorDescriptor& input_nd,
    const ScopedFilterDescriptor& filter,
    const ScopedConvolutionDescriptor& conv,
    const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit,
    const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
    const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
    size_t memory_limit_bytes) {
  cudnnConvolutionFwdPreference_t preference =
      specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
@ -1943,10 +1942,10 @@ port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(

port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
                                    const ScopedTensorDescriptor& input_nd,
                                    const ScopedFilterDescriptor& filter,
                                    const ScopedConvolutionDescriptor& conv,
                                    const ScopedTensorDescriptor& output_nd,
                                    const CudnnTensorDescriptor& input_nd,
                                    const CudnnFilterDescriptor& filter,
                                    const CudnnConvolutionDescriptor& conv,
                                    const CudnnTensorDescriptor& output_nd,
                                    bool specify_workspace_limit,
                                    size_t memory_limit_bytes) {
  cudnnConvolutionBwdDataPreference_t preference =
@ -1962,10 +1961,10 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,

port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
                                      const ScopedTensorDescriptor& input_nd,
                                      const ScopedFilterDescriptor& filter,
                                      const ScopedConvolutionDescriptor& conv,
                                      const ScopedTensorDescriptor& output_nd,
                                      const CudnnTensorDescriptor& input_nd,
                                      const CudnnFilterDescriptor& filter,
                                      const CudnnConvolutionDescriptor& conv,
                                      const CudnnTensorDescriptor& output_nd,
                                      bool specify_workspace_limit,
                                      size_t memory_limit_bytes) {
  cudnnConvolutionBwdFilterPreference_t preference =
@ -1982,10 +1981,9 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
    Stream* stream, const CudnnHandle& cudnn,
    const dnn::AlgorithmDesc& algorithm_desc,
    const ScopedTensorDescriptor& input_nd,
    const ScopedFilterDescriptor& filter,
    const ScopedConvolutionDescriptor& conv,
    const ScopedTensorDescriptor& output_nd,
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd,
    ScratchAllocator* scratch_allocator) {
  // TODO(csigg): This has side effects on the convolution descriptor. It is
  // functionally correct because the convolution is run with the algorithm of
@ -2025,10 +2023,9 @@ port::StatusOr<DeviceMemory<uint8>>
AllocateCudnnConvolutionBackwardDataWorkspace(
    Stream* stream, const CudnnHandle& cudnn,
    const dnn::AlgorithmDesc& algorithm_desc,
    const ScopedTensorDescriptor& input_nd,
    const ScopedFilterDescriptor& filter,
    const ScopedConvolutionDescriptor& conv,
    const ScopedTensorDescriptor& output_nd,
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd,
    ScratchAllocator* scratch_allocator) {
  // TODO(csigg): This has side effects on the convolution descriptor. It is
  // functionally correct because the convolution is run with the algorithm of
@ -2070,10 +2067,9 @@ port::StatusOr<DeviceMemory<uint8>>
AllocateCudnnConvolutionBackwardFilterWorkspace(
    Stream* stream, const CudnnHandle& cudnn,
    const dnn::AlgorithmDesc& algorithm_desc,
    const ScopedTensorDescriptor& input_nd,
    const ScopedFilterDescriptor& filter,
    const ScopedConvolutionDescriptor& conv,
    const ScopedTensorDescriptor& output_nd,
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd,
    ScratchAllocator* scratch_allocator) {
  // TODO(csigg): This has side effects on the convolution descriptor. It is
  // functionally correct because the convolution is run with the algorithm of
@ -2114,11 +2110,10 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
    Stream* stream, const CudnnHandle& cudnn,
    const dnn::AlgorithmConfig& algorithm_config,
    const ScopedTensorDescriptor& input_nd,
    const ScopedFilterDescriptor& filter,
    const ScopedConvolutionDescriptor& conv,
    const ScopedTensorDescriptor& output_nd,
    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
    DeviceMemory<uint8>* scratch) {
  dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
  if (algorithm_config.algorithm().is_default()) {
    // Pick fastest algorithm within memory limit according to cuDNN's
@ -2164,11 +2159,10 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
    Stream* stream, const CudnnHandle& cudnn,
    const dnn::AlgorithmConfig& algorithm_config,
    const ScopedTensorDescriptor& input_nd,
    const ScopedFilterDescriptor& filter,
    const ScopedConvolutionDescriptor& conv,
    const ScopedTensorDescriptor& output_nd,
    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
    DeviceMemory<uint8>* scratch) {
  dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
  if (algorithm_config.algorithm().is_default()) {
    // Pick fastest algorithm within memory limit according to cuDNN's
@ -2214,11 +2208,10 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
    Stream* stream, const CudnnHandle& cudnn,
    const dnn::AlgorithmConfig& algorithm_config,
    const ScopedTensorDescriptor& input_nd,
    const ScopedFilterDescriptor& filter,
    const ScopedConvolutionDescriptor& conv,
    const ScopedTensorDescriptor& output_nd,
    ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
    const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
    const CudnnConvolutionDescriptor& conv,
    const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
    DeviceMemory<uint8>* scratch) {
  dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
  if (algorithm_config.algorithm().is_default()) {
    // Pick fastest algorithm within memory limit according to cuDNN's
@ -2387,11 +2380,11 @@ port::Status CudnnSupport::DoConvolveImpl(
    const dnn::AlgorithmConfig& algorithm_config,
    dnn::ProfileResult* output_profile_result) {
  cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
  ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
  ScopedTensorDescriptor output_nd(output_descriptor, cudnn_type);
  ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
  ScopedConvolutionDescriptor conv(convolution_descriptor,
                                   GetConvComputeType<T>());
  CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
  CudnnTensorDescriptor output_nd(output_descriptor, cudnn_type);
  CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
  CudnnConvolutionDescriptor conv(convolution_descriptor,
                                  GetConvComputeType<T>());

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  // Alpha is the scaling factor for input.
@ -2493,14 +2486,14 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
        "Relu activation.");
  }

  ScopedTensorDescriptor conv_input_nd(
  CudnnTensorDescriptor conv_input_nd(
      conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
  ScopedTensorDescriptor output_nd(
  CudnnTensorDescriptor output_nd(
      output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
  ScopedFilterDescriptor filter(filter_descriptor,
                                static_cast<cudnnDataType_t>(cudnn_data_type));
  ScopedTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
  ScopedConvolutionDescriptor conv(
  CudnnFilterDescriptor filter(filter_descriptor,
                               static_cast<cudnnDataType_t>(cudnn_data_type));
  CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
  CudnnConvolutionDescriptor conv(
      convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));

  auto cudnn = cudnn_->GetHandle(parent_, stream);
@ -2528,7 +2521,7 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
  // activation descriptor. Note that this will change the nan propagation
  // behavior from separate conv, bias, and relu (which by default is
  // CUDNN_PROPAGATE_NAN.
  ScopedActivationDescriptor activation_desc(
  CudnnActivationDescriptor activation_desc(
      activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
  auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
                                                     : side_input_data.opaque();
@ -2740,8 +2733,8 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
    DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
    bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
    std::function<void()> inv_var_to_var) {
  ScopedTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
  ScopedTensorDescriptor scale_offset_descriptor(
  CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
  CudnnTensorDescriptor scale_offset_descriptor(
      scale_offset_desc, ToCudnnDataType(scale_data_type));
  cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7000
@ -2825,9 +2818,9 @@ port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
    const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
    DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
    DeviceMemory<U>* offset_backprop) {
  ScopedTensorDescriptor x_descriptor(
  CudnnTensorDescriptor x_descriptor(
      x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
  ScopedTensorDescriptor scale_offset_descriptor(
  CudnnTensorDescriptor scale_offset_descriptor(
      scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
  cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7000
@ -3017,9 +3010,9 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
                                     dnn::DataType output_type, float scale,
                                     DeviceMemoryBase* output_data) {
  float beta = 0.0f;
  ScopedTensorDescriptor input_tensor_desc(
  CudnnTensorDescriptor input_tensor_desc(
      input_desc, ToCudnnDataType(input_type, input_desc.layout()));
  ScopedTensorDescriptor output_tensor_desc(
  CudnnTensorDescriptor output_tensor_desc(
      output_desc, ToCudnnDataType(output_type, output_desc.layout()));
  auto cudnn = cudnn_->GetHandle(parent_, stream);
  auto status = [&] {
@ -3056,11 +3049,11 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(

  auto cudnn = cudnn_->GetHandle(parent_, stream);

  ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
  ScopedTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
  ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
  ScopedConvolutionDescriptor conv(convolution_descriptor,
                                   GetConvComputeType<T>());
  CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
  CudnnTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
  CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
  CudnnConvolutionDescriptor conv(convolution_descriptor,
                                  GetConvComputeType<T>());

  const bool is_profiling = output_profile_result != nullptr;

@ -3192,11 +3185,11 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(

  auto cudnn = cudnn_->GetHandle(parent_, stream);

  ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
  ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
  ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
  ScopedConvolutionDescriptor conv(convolution_descriptor,
                                   GetConvComputeType<T>());
  CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
  CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
  CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
  CudnnConvolutionDescriptor conv(convolution_descriptor,
                                  GetConvComputeType<T>());

  const bool is_profiling = output_profile_result != nullptr;

@ -3338,8 +3331,8 @@ port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
    const dnn::BatchDescriptor& bias_descriptor,
    DeviceMemory<T>* backward_bias_data) {
  cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
  ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
  ScopedTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
  CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
  CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type);

  // Alpha is the scaling factor for input.
  float alpha = 1.0;
@ -3526,7 +3519,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
                             const DeviceMemory<float>& biases,
                             const dnn::BatchDescriptor& dimensions,
                             DeviceMemory<float>* output_data) {
  ScopedTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
  CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);

  dnn::BatchDescriptor bias_dimensions;
  bias_dimensions.set_count(1)
@ -3534,7 +3527,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
      .set_height(1)
      .set_width(1)
      .set_layout(dnn::DataLayout::kBatchYXDepth);
  ScopedTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
  CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);

  // cudnnAddTensor after R3 is in-place, so we need to copy input_data to
  // output_data before doing the addition, unless the input and
@ -3570,10 +3563,10 @@ bool CudnnSupport::DoActivate(Stream* stream,
                              const DeviceMemory<float>& input_data,
                              DeviceMemory<float>* output_data,
                              uint64 options) {
  ScopedActivationDescriptor activation_desc(
  CudnnActivationDescriptor activation_desc(
      activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());

  ScopedTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
  CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
  // Alpha is the input scaling factor.
  float alpha = 1.0;
  // Beta is the output scaling factor.
@ -3600,9 +3593,9 @@ bool CudnnSupport::DoPoolForward(
  // Beta is the scaling factor for output.
  double beta = 0.0;

  ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
  ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
  ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  auto status = [&] {
@ -3625,9 +3618,9 @@ bool CudnnSupport::DoPoolForward(
  // Beta is the scaling factor for output.
  float beta = 0.0;

  ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
  ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
  ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  auto status = [&] {
@ -3650,9 +3643,9 @@ bool CudnnSupport::DoPoolForward(
  // Beta is the scaling factor for output.
  float beta = 0.0;

  ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
  ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
  ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
  auto cudnn = cudnn_->GetHandle(parent_, stream);
  auto status = [&] {
    RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
@ -3676,9 +3669,9 @@ bool CudnnSupport::DoPoolBackward(
  // Beta is the scaling factor for output.
  double beta = 0.0;

  ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
  ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
  ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  auto status = [&] {
@ -3705,9 +3698,9 @@ bool CudnnSupport::DoPoolBackward(
  // Beta is the scaling factor for output.
  float beta = 0.0;

  ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
  ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
  ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  auto status = [&] {
@ -3734,9 +3727,9 @@ bool CudnnSupport::DoPoolBackward(
  // Beta is the scaling factor for output.
  float beta = 0.0;

  ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
  ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
  ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
  CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
  CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
  CudnnPoolingDescriptor pooling_desc(pooling_dimensions);

  auto cudnn = cudnn_->GetHandle(parent_, stream);
  auto status = [&] {
@ -3771,8 +3764,8 @@ bool CudnnSupport::DoNormalizeWithDimensions(
    return false;
  }

  ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
  ScopedNormalizeDescriptor normalize(normalize_descriptor);
  CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
  CudnnNormalizeDescriptor normalize(normalize_descriptor);

  // Alpha is the scaling factor for input.
  float alpha = 1.0f;
@ -3808,8 +3801,8 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
    return false;
  }

  ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
  ScopedNormalizeDescriptor normalize(normalize_descriptor);
|
||||
CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
|
||||
CudnnNormalizeDescriptor normalize(normalize_descriptor);
|
||||
|
||||
float alpha = 1.0f;
|
||||
float beta = 0.0f;
|
||||
@ -3932,9 +3925,9 @@ bool CudnnSupport::DeriveOutputBatchDescriptor(
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||
dnn::BatchDescriptor* output_batch_descriptor) {
|
||||
ScopedTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
|
||||
ScopedFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
|
||||
ScopedConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
|
||||
CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
|
||||
CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
|
||||
CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
|
||||
|
||||
int dn = batch_descriptor.ndims() + 2;
|
||||
std::vector<int> dims(dn); // in BDYX
|
||||
|
@ -15,9 +15,6 @@ limitations under the License.

#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"

#if defined(__APPLE__)
#include <mach-o/dyld.h>
#endif
#if defined(PLATFORM_WINDOWS)
#include <windows.h>
#define PATH_MAX MAX_PATH
@ -179,19 +176,11 @@ bool CUDAExecutor::FindOnDiskForComputeCapability(
// would return /usr/bin.
static string GetBinaryDir(bool strip_exe) {
char exe_path[PATH_MAX] = {0};
#if defined(__APPLE__)
uint32_t buffer_size = 0U;
_NSGetExecutablePath(nullptr, &buffer_size);
char unresolved_path[buffer_size];
_NSGetExecutablePath(unresolved_path, &buffer_size);
CHECK_ERR(realpath(unresolved_path, exe_path) ? 1 : -1);
#else
#if defined(PLATFORM_WINDOWS)
HMODULE hModule = GetModuleHandle(NULL);
GetModuleFileName(hModule, exe_path, MAX_PATH);
#else
CHECK_ERR(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1));
#endif
#endif
// Make sure it's null-terminated:
exe_path[sizeof(exe_path) - 1] = 0;
@ -854,10 +843,7 @@ CudaContext* CUDAExecutor::cuda_context() { return context_; }
// For anything more complicated/prod-focused than this, you'll likely want to
// turn to gsys' topology modeling.
static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) {
#if defined(__APPLE__)
LOG(INFO) << "OS X does not support NUMA - returning NUMA node zero";
return 0;
#elif defined(PLATFORM_WINDOWS)
#if defined(PLATFORM_WINDOWS)
// Windows support for NUMA is not currently implemented. Return node 0.
return 0;
#elif defined(__aarch64__)
@ -5,12 +5,21 @@ licenses(["notice"]) # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")

py_library(
name = "doc_srcs",
srcs = ["doc_srcs.py"],
srcs_version = "PY2AND3",
)

py_binary(
name = "create_python_api",
srcs = ["create_python_api.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":doc_srcs",
"//tensorflow/python:no_contrib",
],
)
@ -24,3 +33,18 @@ py_test(
"//tensorflow/python:client_testlib",
],
)

py_test(
name = "tensorflow_doc_srcs_test",
srcs = ["doc_srcs_test.py"],
args = [
"--package=tensorflow.python",
] + TENSORFLOW_API_INIT_FILES,
main = "doc_srcs_test.py",
srcs_version = "PY2AND3",
deps = [
":doc_srcs",
"//tensorflow/python:client_testlib",
"//tensorflow/python:no_contrib",
],
)
@ -26,6 +26,7 @@ import sys

from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_export
from tensorflow.tools.api.generator import doc_srcs

API_ATTRS = tf_export.API_ATTRS

@ -36,10 +37,9 @@ _SYMBOLS_TO_SKIP_EXPLICITLY = {
# would have side effects.
'tensorflow.python.platform.flags.FLAGS'
}
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.

This file is MACHINE GENERATED! Do not edit.
Generated by: tensorflow/tools/api/generator/create_python_api.py script.
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
# Generated by: tensorflow/tools/api/generator/create_python_api.py script.
\"\"\"%s
\"\"\"

from __future__ import print_function
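The hunk above turns the generated-file header into a %-format template: each generated __init__.py now starts with the machine-generated notice, with a module docstring spliced in at the %s slot by get_module_docstring (added further down). A minimal sketch of how the pieces compose; the footer value and docstring here are illustrative stand-ins, not the exact constants in create_python_api.py:

# Sketch only: mirrors the shape of the new header template.
_HEADER = ('# This file is MACHINE GENERATED! Do not edit.\n'
           '# Generated by: tensorflow/tools/api/generator/create_python_api.py script.\n'
           '"""%s\n"""\n')
_FOOTER = ''  # assumed empty for this sketch

docstring = 'Public API for tf.math namespace.'
print(_HEADER % docstring + 'from math import *  # stand-in for generated imports\n' + _FOOTER)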
@ -252,6 +252,44 @@ def get_module(dir_path, relative_to_dir):
return dir_path.replace('/', '.').strip('.')


def get_module_docstring(module_name, package):
"""Get the docstring for the given module.

This method looks for a docstring in the following order:
1. Checks if the module has a docstring specified in doc_srcs.
2. Checks if the module has a docstring source module specified
in doc_srcs. If it does, gets the docstring from that module.
3. Checks if a module with module_name exists under the base package.
If it does, gets the docstring from that module.
4. Returns a default docstring.

Args:
module_name: module name relative to tensorflow
(excluding 'tensorflow.' prefix) to get a docstring for.
package: Base python package containing modules with the target tf_export
decorators.

Returns:
One-line docstring to describe the module.
"""
# Module under base package to get a docstring from.
docstring_module_name = module_name

if module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
docsrc = doc_srcs.TENSORFLOW_DOC_SOURCES[module_name]
if docsrc.docstring:
return docsrc.docstring
if docsrc.docstring_module_name:
docstring_module_name = docsrc.docstring_module_name

docstring_module_name = package + '.' + docstring_module_name
if (docstring_module_name in sys.modules and
sys.modules[docstring_module_name].__doc__):
return sys.modules[docstring_module_name].__doc__

return 'Public API for tf.%s namespace.' % module_name
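The four-step lookup documented above can be exercised in isolation. A self-contained sketch of the same resolution order, using a toy DocSource table and whatever happens to be in sys.modules rather than TensorFlow's real doc_srcs registry:

import collections
import sys

DocSource = collections.namedtuple('DocSource', ['docstring', 'docstring_module_name'])
DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)

# Toy stand-in for doc_srcs.TENSORFLOW_DOC_SOURCES.
DOC_SOURCES = {
    'compat': DocSource(docstring='Compatibility helpers.'),
    'math': DocSource(docstring_module_name='ops.math_ops'),
}

def resolve_docstring(module_name, package):
  docstring_module_name = module_name
  docsrc = DOC_SOURCES.get(module_name)
  if docsrc:
    if docsrc.docstring:              # 1. explicit docstring wins
      return docsrc.docstring
    if docsrc.docstring_module_name:  # 2. redirect to a source module
      docstring_module_name = docsrc.docstring_module_name
  full_name = package + '.' + docstring_module_name
  module = sys.modules.get(full_name)
  if module and module.__doc__:       # 3. borrow that module's docstring
    return module.__doc__
  return 'Public API for tf.%s namespace.' % module_name  # 4. default

print(resolve_docstring('compat', 'tensorflow.python'))  # 'Compatibility helpers.'
print(resolve_docstring('image', 'tensorflow.python'))   # falls through to the default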


def create_api_files(
output_files, package, root_init_template, output_dir, api_name):
"""Creates __init__.py files for the Python API.
@ -295,7 +333,9 @@ def create_api_files(
continue
contents = ''
if module or not root_init_template:
contents = _GENERATED_FILE_HEADER + text + _GENERATED_FILE_FOOTER
contents = (
_GENERATED_FILE_HEADER %
get_module_docstring(module, package) + text + _GENERATED_FILE_FOOTER)
else:
# Read base init file
with open(root_init_template, 'r') as root_init_template_file:
@ -308,7 +348,7 @@ def create_api_files(
raise ValueError(
'Missing outputs for python_api_gen genrule:\n%s.'
'Make sure all required outputs are in the '
'tensorflow/tools/api/generator/BUILD file.' %
'tensorflow/tools/api/generator/api_gen.bzl file.' %
',\n'.join(sorted(missing_output_files)))
65
tensorflow/tools/api/generator/doc_srcs.py
Normal file
@ -0,0 +1,65 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Specifies sources of doc strings for API modules."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections


# Specifies docstring source for a module.
# Only one of docstring or docstring_module_name should be set.
# * If docstring is set, then we will use this docstring
# for the module.
# * If docstring_module_name is set, then we will copy the docstring
# from the docstring source module.
DocSource = collections.namedtuple(
'DocSource', ['docstring', 'docstring_module_name'])
# Each attribute of DocSource is optional.
DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)

TENSORFLOW_DOC_SOURCES = {
'app': DocSource(docstring_module_name='platform.app'),
'compat': DocSource(docstring_module_name='util.compat'),
'distributions': DocSource(
docstring_module_name='ops.distributions.distributions'),
'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
'errors': DocSource(docstring_module_name='framework.errors'),
'gfile': DocSource(docstring_module_name='platform.gfile'),
'graph_util': DocSource(docstring_module_name='framework.graph_util'),
'image': DocSource(docstring_module_name='ops.image_ops'),
'keras.estimator': DocSource(docstring_module_name='estimator.keras'),
'linalg': DocSource(docstring_module_name='ops.linalg_ops'),
'logging': DocSource(docstring_module_name='ops.logging_ops'),
'losses': DocSource(docstring_module_name='ops.losses.losses'),
'manip': DocSource(docstring_module_name='ops.manip_ops'),
'math': DocSource(docstring_module_name='ops.math_ops'),
'metrics': DocSource(docstring_module_name='ops.metrics'),
'nn': DocSource(docstring_module_name='ops.nn_ops'),
'nn.rnn_cell': DocSource(docstring_module_name='ops.rnn_cell'),
'python_io': DocSource(docstring_module_name='lib.io.python_io'),
'resource_loader': DocSource(
docstring_module_name='platform.resource_loader'),
'sets': DocSource(docstring_module_name='ops.sets'),
'sparse': DocSource(docstring_module_name='ops.sparse_ops'),
'spectral': DocSource(docstring_module_name='ops.spectral_ops'),
'strings': DocSource(docstring_module_name='ops.string_ops'),
'sysconfig': DocSource(docstring_module_name='platform.sysconfig'),
'test': DocSource(docstring_module_name='platform.test'),
'train': DocSource(docstring_module_name='training.training'),
'train.queue_runner': DocSource(
docstring_module_name='training.queue_runner'),
}
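Worth noting in the new file: assigning to DocSource.__new__.__defaults__ is what makes both namedtuple fields optional, so each entry in TENSORFLOW_DOC_SOURCES can set exactly one of them. A small sketch of that behavior, with toy values independent of the table above:

import collections

DocSource = collections.namedtuple('DocSource', ['docstring', 'docstring_module_name'])
# Default every field to None so call sites can pass just one keyword.
DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)

src = DocSource(docstring_module_name='ops.math_ops')
print(src.docstring)              # None -- left at its default
print(src.docstring_module_name)  # 'ops.math_ops'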
80
tensorflow/tools/api/generator/doc_srcs_test.py
Normal file
@ -0,0 +1,80 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for tensorflow.tools.api.generator.doc_srcs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import importlib
import sys

from tensorflow.python.platform import test
from tensorflow.tools.api.generator import doc_srcs


FLAGS = None


class DocSrcsTest(test.TestCase):

def testModulesAreValidAPIModules(self):
for module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
# Convert module_name to corresponding __init__.py file path.
file_path = module_name.replace('.', '/')
if file_path:
file_path += '/'
file_path += '__init__.py'

if file_path not in FLAGS.outputs:
self.fail('%s is not a valid API module' % module_name)

def testHaveDocstringOrDocstringModule(self):
for module_name, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
if docsrc.docstring and docsrc.docstring_module_name:
self.fail(
'%s contains a DocSource that has both a docstring and a '
'docstring_module_name. '
'Only one of "docstring" or "docstring_module_name" should be set.'
% (module_name))

def testDocstringModulesAreValidModules(self):
for _, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
if docsrc.docstring_module_name:
doc_module_name = '.'.join([
FLAGS.package, docsrc.docstring_module_name])
if doc_module_name not in sys.modules:
self.fail(
'docsources_module %s is not a valid module under %s.' %
(docsrc.docstring_module_name, FLAGS.package))


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'outputs', metavar='O', type=str, nargs='+',
help='create_python_api output files.')
parser.add_argument(
'--package', type=str,
help='Base package that imports modules containing the target tf_export '
'decorators.')
FLAGS, unparsed = parser.parse_known_args()

importlib.import_module(FLAGS.package)

# Now update argv, so that unittest library does not get confused.
sys.argv = [sys.argv[0]] + unparsed
test.main()
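The __main__ block above follows a common pattern for tests that take their own flags: parse_known_args() peels off the custom arguments, and sys.argv is rebuilt from the leftovers so the unittest machinery only sees flags it understands. A standalone sketch of the same idea, using plain unittest rather than the TensorFlow test wrapper:

import argparse
import sys
import unittest

FLAGS = None

class ExampleTest(unittest.TestCase):

  def test_package_flag_is_visible(self):
    # Parsed flags are available to every test via the module global.
    self.assertTrue(FLAGS.package)

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--package', type=str, default='tensorflow.python')
  # Split our flags from whatever unittest itself understands.
  FLAGS, unparsed = parser.parse_known_args()
  # Hand the unparsed remainder back so unittest does not choke on --package.
  sys.argv = [sys.argv[0]] + unparsed
  unittest.main()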
@ -20,6 +20,10 @@ tf_module {
name: "adjust_hue"
argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "adjust_jpeg_quality"
argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "adjust_saturation"
argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -144,6 +148,10 @@ tf_module {
name: "random_hue"
argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "random_jpeg_quality"
argspec: "args=[\'image\', \'min_jpeg_quality\', \'max_jpeg_quality\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "random_saturation"
argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -166,7 +174,7 @@ tf_module {
}
member_method {
name: "resize_images"
argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\'], varargs=None, keywords=None, defaults=[\'0\', \'False\'], "
argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], "
}
member_method {
name: "resize_nearest_neighbor"
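The golden update above records the new preserve_aspect_ratio argument on resize_images. Assuming the TF 1.x API of this branch, a call would look roughly like the following; the shapes and target size are illustrative:

import tensorflow as tf

images = tf.zeros([1, 480, 640, 3])  # one 480x640 RGB image
# With preserve_aspect_ratio=True the image is scaled until it fits inside
# the target size while keeping its original width:height ratio.
resized = tf.image.resize_images(
    images, size=[256, 256],
    method=tf.image.ResizeMethod.BILINEAR,
    align_corners=False,
    preserve_aspect_ratio=True)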
@ -1,5 +1,21 @@
path: "tensorflow.math"
tf_module {
member_method {
name: "bessel_i0"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i0\'], "
}
member_method {
name: "bessel_i0e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "bessel_i1"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i1\'], "
}
member_method {
name: "bessel_i1e"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "polyval"
argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -2,7 +2,7 @@ path: "tensorflow.saved_model.loader"
tf_module {
member_method {
name: "load"
argspec: "args=[\'sess\', \'tags\', \'export_dir\'], varargs=None, keywords=saver_kwargs, defaults=None"
argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], "
}
member_method {
name: "maybe_saved_model_directory"
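This hunk records the new import_scope parameter on saved_model.loader.load. A hedged sketch against the TF 1.x session API; the export directory and scope name are placeholders:

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
  # import_scope prefixes every restored op and variable name, which keeps
  # two models loaded into one graph from colliding.
  meta_graph_def = tf.saved_model.loader.load(
      sess, [tf.saved_model.tag_constants.SERVING], '/tmp/my_model',
      import_scope='model_a')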
@ -778,11 +778,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
actual = "@grpc//:grpc_python_plugin",
)

# gRPC has three empty C++ functions which it wants the user to define
# at build time. https://github.com/grpc/grpc/issues/13590
native.bind(
name = "grpc_lib",
actual = "@grpc//:grpc++_unsecure",
actual = "@grpc//:grpc++",
)

# Needed by gRPC