Merge commit for internal changes
This commit is contained in:
commit
e33538d114
@ -142,8 +142,10 @@ void TestRemoteExecute(bool async) {
|
|||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
|
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
|
||||||
status);
|
status);
|
||||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
|
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
|
||||||
|
TFE_ContextOptionsSetDevicePlacementPolicy(opts,
|
||||||
|
TFE_DEVICE_PLACEMENT_EXPLICIT);
|
||||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
TFE_DeleteContextOptions(opts);
|
TFE_DeleteContextOptions(opts);
|
||||||
@ -205,6 +207,83 @@ void TestRemoteExecute(bool async) {
|
|||||||
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
|
||||||
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
|
||||||
|
|
||||||
|
void TestRemoteExecuteSilentCopies(bool async) {
|
||||||
|
tensorflow::ServerDef server_def = GetServerDef(2);
|
||||||
|
|
||||||
|
// This server def has the task index set to 0.
|
||||||
|
string serialized = server_def.SerializeAsString();
|
||||||
|
|
||||||
|
server_def.set_task_index(1);
|
||||||
|
|
||||||
|
std::unique_ptr<tensorflow::eager::EagerGrpcServer> worker_server;
|
||||||
|
ASSERT_TRUE(
|
||||||
|
tensorflow::eager::EagerGrpcServer::Create(server_def, &worker_server)
|
||||||
|
.ok());
|
||||||
|
ASSERT_TRUE(worker_server->Start().ok());
|
||||||
|
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
TFE_ContextOptionsSetServerDef(opts, serialized.data(), serialized.size(),
|
||||||
|
status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(1));
|
||||||
|
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
|
||||||
|
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
|
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle();
|
||||||
|
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle();
|
||||||
|
const char remote_device_name[] =
|
||||||
|
"/job:localhost/replica:0/task:1/device:CPU:0";
|
||||||
|
|
||||||
|
// Handles are on task0, but op is on remote (task1).
|
||||||
|
TFE_Op* matmul = MatMulOp(ctx, h0_task0, h1_task0);
|
||||||
|
TFE_OpSetDevice(matmul, remote_device_name, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TFE_TensorHandle* retvals[1];
|
||||||
|
int num_retvals = 1;
|
||||||
|
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
|
||||||
|
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
|
||||||
|
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TFE_DeleteTensorHandle(retval_task0);
|
||||||
|
float product[4] = {0};
|
||||||
|
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
||||||
|
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||||
|
TF_DeleteTensor(t);
|
||||||
|
EXPECT_EQ(7, product[0]);
|
||||||
|
EXPECT_EQ(10, product[1]);
|
||||||
|
EXPECT_EQ(15, product[2]);
|
||||||
|
EXPECT_EQ(22, product[3]);
|
||||||
|
|
||||||
|
TFE_DeleteTensorHandle(h0_task0);
|
||||||
|
TFE_DeleteTensorHandle(h1_task0);
|
||||||
|
TFE_DeleteTensorHandle(retvals[0]);
|
||||||
|
|
||||||
|
TFE_DeleteOp(matmul);
|
||||||
|
|
||||||
|
TFE_ContextAsyncWait(ctx, status);
|
||||||
|
TFE_DeleteContext(ctx, status);
|
||||||
|
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
|
||||||
|
// TODO(nareshmodi): Figure out how to correctly shut the server down.
|
||||||
|
worker_server.release();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CAPI, RemoteExecuteSilentCopies) { TestRemoteExecuteSilentCopies(false); }
|
||||||
|
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
|
||||||
|
TestRemoteExecuteSilentCopies(true);
|
||||||
|
}
|
||||||
|
|
||||||
TEST(CAPI, TensorHandle) {
|
TEST(CAPI, TensorHandle) {
|
||||||
TFE_TensorHandle* h = TestMatrixTensorHandle();
|
TFE_TensorHandle* h = TestMatrixTensorHandle();
|
||||||
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(h));
|
||||||
|
@ -42,7 +42,7 @@ tf_cc_binary(
|
|||||||
"//tensorflow/compiler/xla/service:cpu_plugin",
|
"//tensorflow/compiler/xla/service:cpu_plugin",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -74,6 +74,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/service",
|
"//tensorflow/compiler/xla/service",
|
||||||
"//tensorflow/compiler/xla/service:platform_util",
|
"//tensorflow/compiler/xla/service:platform_util",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -2379,7 +2379,6 @@ cc_library(
|
|||||||
":hlo_graph_dumper",
|
":hlo_graph_dumper",
|
||||||
":hlo_pass",
|
":hlo_pass",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -2574,6 +2573,7 @@ cc_library(
|
|||||||
hdrs = ["hlo_graph_dumper.h"],
|
hdrs = ["hlo_graph_dumper.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":hlo_casting_utils",
|
||||||
":hlo_execution_profile",
|
":hlo_execution_profile",
|
||||||
":hlo_tfgraph_builder",
|
":hlo_tfgraph_builder",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
|
@ -47,12 +47,16 @@ bool GpuMultiOutputFusion::ShapesCompatibleForFusion(HloInstruction* instr1,
|
|||||||
element_instr = fused_expression_root;
|
element_instr = fused_expression_root;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Special handling of kReduce instructions -- the fusion
|
||||||
|
// applies to the first operand.
|
||||||
|
if (element_instr->opcode() == HloOpcode::kReduce) {
|
||||||
|
return element_instr->operand(0)->shape();
|
||||||
|
}
|
||||||
return element_instr->shape();
|
return element_instr->shape();
|
||||||
};
|
};
|
||||||
|
|
||||||
// The elementwise output shapes must be the same (including layout)
|
// The elementwise output shapes must be the same (including layout)
|
||||||
return ShapeUtil::ShapeUtil::Equal(get_element_shape(instr1),
|
return ShapeUtil::Equal(get_element_shape(instr1), get_element_shape(instr2));
|
||||||
get_element_shape(instr2));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
|
bool GpuMultiOutputFusion::IsProfitableOperand(HloInstruction* instr) {
|
||||||
|
@ -36,6 +36,11 @@ const char kModulePrefix[] = R"(
|
|||||||
scalar_lhs = f32[] parameter(0)
|
scalar_lhs = f32[] parameter(0)
|
||||||
scalar_rhs = f32[] parameter(1)
|
scalar_rhs = f32[] parameter(1)
|
||||||
ROOT add = f32[] add(scalar_lhs, scalar_rhs)
|
ROOT add = f32[] add(scalar_lhs, scalar_rhs)
|
||||||
|
}
|
||||||
|
scalar_mul_computation {
|
||||||
|
scalar_lhs = f32[] parameter(0)
|
||||||
|
scalar_rhs = f32[] parameter(1)
|
||||||
|
ROOT mul = f32[] add(scalar_lhs, scalar_rhs)
|
||||||
})";
|
})";
|
||||||
|
|
||||||
TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
|
TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
|
||||||
@ -67,6 +72,34 @@ TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceAndReduceFusion) {
|
|||||||
op::Tuple(op::Reduce(), op::Reduce()));
|
op::Tuple(op::Reduce(), op::Reduce()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(InstructionFusionTest, MultiOutputFusionDifferentReduceInputShapes) {
|
||||||
|
auto module = ParseHloString(tensorflow::strings::StrCat(kModulePrefix, R"(
|
||||||
|
fused_computation_1 {
|
||||||
|
p1.1 = f32[6400]{0} parameter(1)
|
||||||
|
mul = f32[6400]{0} multiply(p1.1, p1.1)
|
||||||
|
const.1 = f32[] parameter(0)
|
||||||
|
ROOT reduce.1 = f32[] reduce(p1.1, const.1), dimensions={0}, to_apply=scalar_add_computation
|
||||||
|
}
|
||||||
|
|
||||||
|
fused_computation_2 {
|
||||||
|
p1.2 = f32[6400]{0} parameter(1)
|
||||||
|
r1 = f32[64,100]{0,1} reshape(p1.2)
|
||||||
|
const.2 = f32[] parameter(0)
|
||||||
|
ROOT reduce.2 = f32[] reduce(r1, const.2), dimensions={1,0}, to_apply=scalar_mul_computation
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
p0 = f32[] parameter(0)
|
||||||
|
p1 = f32[6400]{0} parameter(1)
|
||||||
|
const.2 = f32[] constant(1)
|
||||||
|
fusion.1 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_1
|
||||||
|
fusion.2 = f32[] fusion(p0, p1), kind=kInput, calls=fused_computation_2
|
||||||
|
ROOT root = (f32[], f32[]) tuple(fusion.1, fusion.2)
|
||||||
|
})"))
|
||||||
|
.ValueOrDie();
|
||||||
|
ASSERT_FALSE(GpuMultiOutputFusion().Run(module.get()).ValueOrDie());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
|
TEST_F(InstructionFusionTest, MultiOutputFusionSiblingReduceFusions) {
|
||||||
// Two sibling fusions with reduce instruction roots sharing the same input
|
// Two sibling fusions with reduce instruction roots sharing the same input
|
||||||
// param.
|
// param.
|
||||||
|
@ -357,7 +357,6 @@ std::list<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
|
|||||||
std::list<HloInstruction*> post_order;
|
std::list<HloInstruction*> post_order;
|
||||||
std::list<HloInstruction*> trace_instructions;
|
std::list<HloInstruction*> trace_instructions;
|
||||||
tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
|
tensorflow::gtl::FlatSet<HloInstruction*> added_instructions;
|
||||||
std::vector<HloInstruction> dfs_stack;
|
|
||||||
for (auto& instruction : instructions_) {
|
for (auto& instruction : instructions_) {
|
||||||
if (instruction->opcode() == HloOpcode::kTrace) {
|
if (instruction->opcode() == HloOpcode::kTrace) {
|
||||||
// Trace instructions aren't handled by the DFS visitor. Add trace
|
// Trace instructions aren't handled by the DFS visitor. Add trace
|
||||||
|
@ -28,6 +28,8 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/layout_util.h"
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
|
#include "tensorflow/compiler/xla/service/hlo_tfgraph_builder.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
@ -723,17 +725,14 @@ string HloDotDumper::DumpRootTag() {
|
|||||||
to_id, node_body, node_shape, NodeColorAttributes(color));
|
to_id, node_body, node_shape, NodeColorAttributes(color));
|
||||||
}
|
}
|
||||||
|
|
||||||
static const HloInstruction* TryGetFusionParameterConstant(
|
static const HloConstantInstruction* TryGetFusionParameterConstant(
|
||||||
const HloInstruction* instr) {
|
const HloInstruction* instr) {
|
||||||
if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
|
if (instr->opcode() != HloOpcode::kParameter || !instr->IsFused()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
const HloInstruction* fusion = instr->parent()->FusionInstruction();
|
const HloInstruction* fusion = instr->parent()->FusionInstruction();
|
||||||
const HloInstruction* operand = fusion->operand(instr->parameter_number());
|
const HloInstruction* operand = fusion->operand(instr->parameter_number());
|
||||||
if (operand->opcode() == HloOpcode::kConstant) {
|
return DynCast<HloConstantInstruction>(operand);
|
||||||
return operand;
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
|
bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
|
||||||
@ -826,7 +825,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
|
|||||||
|
|
||||||
string HloDotDumper::GetInstructionNodeInlinedOperands(
|
string HloDotDumper::GetInstructionNodeInlinedOperands(
|
||||||
const HloInstruction* instr) {
|
const HloInstruction* instr) {
|
||||||
auto stringify_constant = [](const HloInstruction* constant) {
|
auto stringify_constant = [](const HloConstantInstruction* constant) {
|
||||||
const auto& shape = constant->shape();
|
const auto& shape = constant->shape();
|
||||||
|
|
||||||
// If the shape has a dimension of size zero, print it as e.g.
|
// If the shape has a dimension of size zero, print it as e.g.
|
||||||
@ -845,7 +844,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
|
|||||||
*elem_count *= dim;
|
*elem_count *= dim;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
|
if (elem_count.has_value() && *elem_count <= 8) {
|
||||||
return Printf("%s (%s)", constant->literal().ToString(),
|
return Printf("%s (%s)", constant->literal().ToString(),
|
||||||
ShapeUtil::HumanString(constant->shape()));
|
ShapeUtil::HumanString(constant->shape()));
|
||||||
}
|
}
|
||||||
@ -864,9 +863,10 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
|
|||||||
std::vector<string> lines;
|
std::vector<string> lines;
|
||||||
for (int64 i = 0; i < instr->operand_count(); ++i) {
|
for (int64 i = 0; i < instr->operand_count(); ++i) {
|
||||||
const HloInstruction* operand = instr->operand(i);
|
const HloInstruction* operand = instr->operand(i);
|
||||||
|
const auto* constant_operand = DynCast<HloConstantInstruction>(operand);
|
||||||
optional<string> operand_str;
|
optional<string> operand_str;
|
||||||
if (operand->opcode() == HloOpcode::kConstant) {
|
if (constant_operand != nullptr) {
|
||||||
operand_str = stringify_constant(operand);
|
operand_str = stringify_constant(constant_operand);
|
||||||
} else if (ShouldMergeIntoUsers(operand)) {
|
} else if (ShouldMergeIntoUsers(operand)) {
|
||||||
// Special case: If the operand is a parameter to a fusion node and it
|
// Special case: If the operand is a parameter to a fusion node and it
|
||||||
// always has a constant value, display it like a regular constant.
|
// always has a constant value, display it like a regular constant.
|
||||||
@ -874,7 +874,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
|
|||||||
// For other parameters, use the parameter number rather than the proper
|
// For other parameters, use the parameter number rather than the proper
|
||||||
// name, because that's generally how people think of the node.
|
// name, because that's generally how people think of the node.
|
||||||
if (operand->opcode() == HloOpcode::kParameter) {
|
if (operand->opcode() == HloOpcode::kParameter) {
|
||||||
if (const HloInstruction* constant =
|
if (const HloConstantInstruction* constant =
|
||||||
TryGetFusionParameterConstant(operand)) {
|
TryGetFusionParameterConstant(operand)) {
|
||||||
operand_str = stringify_constant(constant);
|
operand_str = stringify_constant(constant);
|
||||||
} else {
|
} else {
|
||||||
|
@ -178,6 +178,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
slice_limits, slice_strides);
|
slice_limits, slice_strides);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case HloOpcode::kConstant: {
|
||||||
|
CHECK(proto.has_literal());
|
||||||
|
TF_ASSIGN_OR_RETURN(auto literal,
|
||||||
|
Literal::CreateFromProto(proto.literal()));
|
||||||
|
instruction = CreateConstant(std::move(literal));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case HloOpcode::kTrace: {
|
||||||
|
TF_RET_CHECK(proto.operand_ids_size() == 1)
|
||||||
|
<< "Trace instruction should have 1 operand but sees "
|
||||||
|
<< proto.operand_ids_size();
|
||||||
|
CHECK(proto.has_literal());
|
||||||
|
TF_ASSIGN_OR_RETURN(auto literal,
|
||||||
|
Literal::CreateFromProto(proto.literal()));
|
||||||
|
instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
|
||||||
|
break;
|
||||||
|
}
|
||||||
default: {
|
default: {
|
||||||
instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
|
instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
|
||||||
for (const int64 operand_id : proto.operand_ids()) {
|
for (const int64 operand_id : proto.operand_ids()) {
|
||||||
@ -223,22 +240,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
instruction->called_computations_.push_back(fused_computation);
|
instruction->called_computations_.push_back(fused_computation);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (instruction->opcode() == HloOpcode::kTrace) {
|
|
||||||
TF_RET_CHECK(instruction->operands().size() == 1)
|
|
||||||
<< "Trace instruction should have 1 operand but sees "
|
|
||||||
<< instruction->operands().size();
|
|
||||||
instruction->mutable_operand(0)->set_tracing(instruction.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_RET_CHECK(!proto.name().empty());
|
TF_RET_CHECK(!proto.name().empty());
|
||||||
instruction->SetAndSanitizeName(proto.name());
|
instruction->SetAndSanitizeName(proto.name());
|
||||||
|
|
||||||
instruction->metadata_ = proto.metadata();
|
instruction->metadata_ = proto.metadata();
|
||||||
instruction->backend_config_ = proto.backend_config();
|
instruction->backend_config_ = proto.backend_config();
|
||||||
if (proto.has_literal()) {
|
|
||||||
TF_ASSIGN_OR_RETURN(instruction->literal_,
|
|
||||||
Literal::CreateFromProto(proto.literal()));
|
|
||||||
}
|
|
||||||
instruction->parameter_number_ = proto.parameter_number();
|
instruction->parameter_number_ = proto.parameter_number();
|
||||||
|
|
||||||
instruction->tuple_index_ = proto.tuple_index();
|
instruction->tuple_index_ = proto.tuple_index();
|
||||||
@ -301,20 +307,12 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTrace(
|
||||||
const string& tag, HloInstruction* operand) {
|
const string& tag, HloInstruction* operand) {
|
||||||
auto instruction =
|
return MakeUnique<HloTraceInstruction>(tag, operand);
|
||||||
WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
|
|
||||||
instruction->operands_.push_back(operand);
|
|
||||||
instruction->literal_ = Literal::CreateR1U8(tag);
|
|
||||||
operand->set_tracing(instruction.get());
|
|
||||||
return instruction;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
|
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateConstant(
|
||||||
std::unique_ptr<Literal> literal) {
|
std::unique_ptr<Literal> literal) {
|
||||||
auto instruction =
|
return MakeUnique<HloConstantInstruction>(std::move(literal));
|
||||||
WrapUnique(new HloInstruction(HloOpcode::kConstant, literal->shape()));
|
|
||||||
instruction->literal_ = std::move(literal);
|
|
||||||
return instruction;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ std::unique_ptr<HloInstruction>
|
/* static */ std::unique_ptr<HloInstruction>
|
||||||
@ -1321,6 +1319,8 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|||||||
case HloOpcode::kBroadcast:
|
case HloOpcode::kBroadcast:
|
||||||
case HloOpcode::kMap:
|
case HloOpcode::kMap:
|
||||||
case HloOpcode::kSlice:
|
case HloOpcode::kSlice:
|
||||||
|
case HloOpcode::kConstant:
|
||||||
|
case HloOpcode::kTrace:
|
||||||
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
|
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
|
||||||
break;
|
break;
|
||||||
// Unary ops.
|
// Unary ops.
|
||||||
@ -1470,9 +1470,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|||||||
clone =
|
clone =
|
||||||
CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
|
CreateWhile(shape, while_condition(), while_body(), new_operands[0]);
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kConstant:
|
|
||||||
clone = CreateConstant(literal_->CloneToUnique());
|
|
||||||
break;
|
|
||||||
case HloOpcode::kFusion: {
|
case HloOpcode::kFusion: {
|
||||||
HloModule* module = context != nullptr ? context->module() : GetModule();
|
HloModule* module = context != nullptr ? context->module() : GetModule();
|
||||||
HloComputation* new_fused_computation = nullptr;
|
HloComputation* new_fused_computation = nullptr;
|
||||||
@ -1520,8 +1517,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|||||||
case HloOpcode::kGenerateToken:
|
case HloOpcode::kGenerateToken:
|
||||||
clone = CreateGenerateToken(new_operands);
|
clone = CreateGenerateToken(new_operands);
|
||||||
break;
|
break;
|
||||||
case HloOpcode::kTrace:
|
|
||||||
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
|
|
||||||
}
|
}
|
||||||
SetupDerivedInstruction(clone.get());
|
SetupDerivedInstruction(clone.get());
|
||||||
clone->set_parent(parent_);
|
clone->set_parent(parent_);
|
||||||
@ -1602,13 +1597,6 @@ const HloInstruction* HloInstruction::LatestNonGteAncestor() const {
|
|||||||
return hlo;
|
return hlo;
|
||||||
}
|
}
|
||||||
|
|
||||||
const Literal& HloInstruction::literal() const {
|
|
||||||
CHECK_EQ(HloOpcode::kConstant, opcode_);
|
|
||||||
return *literal_;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool HloInstruction::HasLiteral() const { return literal_ != nullptr; }
|
|
||||||
|
|
||||||
int64 HloInstruction::tuple_index() const {
|
int64 HloInstruction::tuple_index() const {
|
||||||
CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
|
CHECK_EQ(HloOpcode::kGetTupleElement, opcode_);
|
||||||
return tuple_index_;
|
return tuple_index_;
|
||||||
@ -1702,10 +1690,6 @@ void HloInstruction::AddUser(HloInstruction* user) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool HloInstruction::IsConstant() const {
|
|
||||||
return opcode_ == HloOpcode::kConstant;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool HloInstruction::HasConstantOperand() const {
|
bool HloInstruction::HasConstantOperand() const {
|
||||||
for (const HloInstruction* operand : operands_) {
|
for (const HloInstruction* operand : operands_) {
|
||||||
if (operand->IsConstant()) {
|
if (operand->IsConstant()) {
|
||||||
@ -1782,7 +1766,6 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
// These opcodes have complex or special behavior so just return false.
|
// These opcodes have complex or special behavior so just return false.
|
||||||
case HloOpcode::kDomain:
|
case HloOpcode::kDomain:
|
||||||
case HloOpcode::kRng:
|
case HloOpcode::kRng:
|
||||||
case HloOpcode::kTrace:
|
|
||||||
case HloOpcode::kWhile:
|
case HloOpcode::kWhile:
|
||||||
case HloOpcode::kGenerateToken:
|
case HloOpcode::kGenerateToken:
|
||||||
return false;
|
return false;
|
||||||
@ -1790,10 +1773,6 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
case HloOpcode::kParameter:
|
case HloOpcode::kParameter:
|
||||||
return parameter_number() == other.parameter_number();
|
return parameter_number() == other.parameter_number();
|
||||||
|
|
||||||
// A constant is defined by the value in the literal.
|
|
||||||
case HloOpcode::kConstant:
|
|
||||||
return literal() == other.literal();
|
|
||||||
|
|
||||||
// A reduce-precision operation is determined by the bit sizes.
|
// A reduce-precision operation is determined by the bit sizes.
|
||||||
case HloOpcode::kReducePrecision:
|
case HloOpcode::kReducePrecision:
|
||||||
return exponent_bits() == other.exponent_bits() &&
|
return exponent_bits() == other.exponent_bits() &&
|
||||||
@ -1878,6 +1857,8 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
case HloOpcode::kBroadcast:
|
case HloOpcode::kBroadcast:
|
||||||
case HloOpcode::kMap:
|
case HloOpcode::kMap:
|
||||||
case HloOpcode::kSlice:
|
case HloOpcode::kSlice:
|
||||||
|
case HloOpcode::kConstant:
|
||||||
|
case HloOpcode::kTrace:
|
||||||
LOG(FATAL) << "Base class impl called for opcode with subclass: "
|
LOG(FATAL) << "Base class impl called for opcode with subclass: "
|
||||||
<< opcode();
|
<< opcode();
|
||||||
}
|
}
|
||||||
@ -2172,34 +2153,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
|
|||||||
const HloPrintOptions& options,
|
const HloPrintOptions& options,
|
||||||
CanonicalNameMap* canonical_name_map) const {
|
CanonicalNameMap* canonical_name_map) const {
|
||||||
string operands;
|
string operands;
|
||||||
if (opcode() == HloOpcode::kConstant) {
|
if (opcode() == HloOpcode::kParameter) {
|
||||||
// For constants, show the actual value in place of an empty operand list.
|
|
||||||
//
|
|
||||||
// In HloInstruction, sometimes a constant literal is not constructed due
|
|
||||||
// to its size. Skip the printing in this case.
|
|
||||||
if (HasLiteral() && ((!ShapeUtil::IsTuple(shape()) &&
|
|
||||||
ShapeUtil::ElementsIn(shape()) <= 10) ||
|
|
||||||
options.print_large_constants())) {
|
|
||||||
// Literal::ToString emits multidimensional arrays over multiple
|
|
||||||
// lines. Compact this into one line by stripping out white space.
|
|
||||||
string tmp = literal().ToString();
|
|
||||||
std::replace(tmp.begin(), tmp.end(), '\n', ' ');
|
|
||||||
std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
|
|
||||||
bool first = true;
|
|
||||||
// Concatenate elements in "v" with spaces separating them, but ignoring
|
|
||||||
// empty entries.
|
|
||||||
for (const auto& s : v) {
|
|
||||||
if (s.empty()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
StrAppend(&operands, (first ? "" : " "), s);
|
|
||||||
first = false;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Do not show large constants or tuples.
|
|
||||||
operands = "{...}";
|
|
||||||
}
|
|
||||||
} else if (opcode() == HloOpcode::kParameter) {
|
|
||||||
StrAppend(&operands, parameter_number_);
|
StrAppend(&operands, parameter_number_);
|
||||||
} else {
|
} else {
|
||||||
tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
|
tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
|
||||||
@ -2410,9 +2364,6 @@ HloInstructionProto HloInstruction::ToProto() const {
|
|||||||
|
|
||||||
*proto.mutable_metadata() = metadata_;
|
*proto.mutable_metadata() = metadata_;
|
||||||
proto.set_backend_config(backend_config_);
|
proto.set_backend_config(backend_config_);
|
||||||
if (literal_ != nullptr) {
|
|
||||||
*proto.mutable_literal() = literal_->ToProto();
|
|
||||||
}
|
|
||||||
proto.set_parameter_number(parameter_number_);
|
proto.set_parameter_number(parameter_number_);
|
||||||
if (opcode() == HloOpcode::kFusion) {
|
if (opcode() == HloOpcode::kFusion) {
|
||||||
proto.set_fusion_kind(xla::ToString(fusion_kind()));
|
proto.set_fusion_kind(xla::ToString(fusion_kind()));
|
||||||
@ -2518,12 +2469,6 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) {
|
|||||||
trace_instruction_ = trace_instruction;
|
trace_instruction_ = trace_instruction;
|
||||||
}
|
}
|
||||||
|
|
||||||
string HloInstruction::TracingTag() const {
|
|
||||||
CHECK_EQ(HloOpcode::kTrace, opcode());
|
|
||||||
CHECK(literal_ != nullptr);
|
|
||||||
return literal_->GetR1U8AsString();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
|
bool HloInstruction::IsFused() const { return parent_->IsFusionComputation(); }
|
||||||
|
|
||||||
bool HloInstruction::IsFusable() const {
|
bool HloInstruction::IsFusable() const {
|
||||||
@ -3035,10 +2980,6 @@ bool HloInstruction::IsElementwiseBinary() const {
|
|||||||
|
|
||||||
bool HloInstruction::IsElementwise() const {
|
bool HloInstruction::IsElementwise() const {
|
||||||
switch (opcode_) {
|
switch (opcode_) {
|
||||||
// Nullary elementwise operations.
|
|
||||||
case HloOpcode::kConstant:
|
|
||||||
return true;
|
|
||||||
|
|
||||||
// Unary elementwise operations.
|
// Unary elementwise operations.
|
||||||
case HloOpcode::kAbs:
|
case HloOpcode::kAbs:
|
||||||
case HloOpcode::kRoundNearestAfz:
|
case HloOpcode::kRoundNearestAfz:
|
||||||
@ -3500,23 +3441,6 @@ void HloInstruction::set_outer_dimension_partitions(
|
|||||||
outer_dimension_partitions_ = outer_dimension_partitions;
|
outer_dimension_partitions_ = outer_dimension_partitions;
|
||||||
}
|
}
|
||||||
|
|
||||||
void HloInstruction::RelayoutConstant(const Layout& new_layout,
|
|
||||||
const ShapeIndex& shape_index) {
|
|
||||||
CHECK_EQ(opcode(), HloOpcode::kConstant);
|
|
||||||
Shape* mutable_array_subshape =
|
|
||||||
ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
|
|
||||||
CHECK(ShapeUtil::IsArray(*mutable_array_subshape));
|
|
||||||
|
|
||||||
// Normally array_subshape will always have a layout, but this invariant is
|
|
||||||
// temporarily broken in LayoutAssignment::AssignLayouts.
|
|
||||||
|
|
||||||
if (!mutable_array_subshape->has_layout() ||
|
|
||||||
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
|
|
||||||
literal_ = literal_->Relayout(new_layout, shape_index);
|
|
||||||
*mutable_array_subshape->mutable_layout() = new_layout;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(b/80131774): Remove these temporary methods after transition.
|
// TODO(b/80131774): Remove these temporary methods after transition.
|
||||||
int64 HloInstruction::feature_index() const {
|
int64 HloInstruction::feature_index() const {
|
||||||
return Cast<HloBatchNormInstruction>(this)->feature_index();
|
return Cast<HloBatchNormInstruction>(this)->feature_index();
|
||||||
@ -3574,4 +3498,21 @@ const std::vector<int64>& HloInstruction::slice_strides() const {
|
|||||||
bool HloInstruction::IsInPlaceSlice() const {
|
bool HloInstruction::IsInPlaceSlice() const {
|
||||||
return Cast<HloSliceInstruction>(this)->IsInPlaceSlice();
|
return Cast<HloSliceInstruction>(this)->IsInPlaceSlice();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const Literal& HloInstruction::literal() const {
|
||||||
|
return Cast<HloConstantInstruction>(this)->literal();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HloInstruction::IsConstant() const {
|
||||||
|
return DynCast<HloConstantInstruction>(this) != nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
void HloInstruction::RelayoutConstant(const Layout& new_layout,
|
||||||
|
const ShapeIndex& shape_index) {
|
||||||
|
Cast<HloConstantInstruction>(this)->RelayoutConstant(new_layout, shape_index);
|
||||||
|
}
|
||||||
|
|
||||||
|
string HloInstruction::TracingTag() const {
|
||||||
|
return Cast<HloTraceInstruction>(this)->TracingTag();
|
||||||
|
}
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -875,14 +875,6 @@ class HloInstruction {
|
|||||||
template <typename HloInstructionPtr>
|
template <typename HloInstructionPtr>
|
||||||
Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
|
Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
|
||||||
|
|
||||||
// Returns the literal associated with this instruction.
|
|
||||||
//
|
|
||||||
// Note: only constant and parameter opcodes have an associated literal.
|
|
||||||
const Literal& literal() const;
|
|
||||||
|
|
||||||
// Returns whether there is literal associated with this instruction.
|
|
||||||
bool HasLiteral() const;
|
|
||||||
|
|
||||||
// Returns the parameter number associated with this instruction.
|
// Returns the parameter number associated with this instruction.
|
||||||
//
|
//
|
||||||
// Note: only parameter opcodes have an associated parameter number.
|
// Note: only parameter opcodes have an associated parameter number.
|
||||||
@ -1014,14 +1006,6 @@ class HloInstruction {
|
|||||||
string infeed_config() const { return infeed_config_; }
|
string infeed_config() const { return infeed_config_; }
|
||||||
void set_infeed_config(const string& config) { infeed_config_ = config; }
|
void set_infeed_config(const string& config) { infeed_config_ = config; }
|
||||||
|
|
||||||
// Returns a tag to be used in tracing.
|
|
||||||
//
|
|
||||||
// Precondition: opcode() == HloOpcode::kTrace
|
|
||||||
string TracingTag() const;
|
|
||||||
|
|
||||||
// Returns whether the instruction is a constant.
|
|
||||||
bool IsConstant() const;
|
|
||||||
|
|
||||||
// Returns true if this instruction is fused, ie contained within a fusion
|
// Returns true if this instruction is fused, ie contained within a fusion
|
||||||
// instruction.
|
// instruction.
|
||||||
bool IsFused() const;
|
bool IsFused() const;
|
||||||
@ -1452,12 +1436,6 @@ class HloInstruction {
|
|||||||
void set_outer_dimension_partitions(
|
void set_outer_dimension_partitions(
|
||||||
const std::vector<int64>& outer_dimension_partitions);
|
const std::vector<int64>& outer_dimension_partitions);
|
||||||
|
|
||||||
// Change the layout for an Constant Hlo instruction to match new_layout. For
|
|
||||||
// tuple shaped constants shape_index is the path to the internal array
|
|
||||||
// subshape whose layout needs to be changed.
|
|
||||||
void RelayoutConstant(const Layout& new_layout,
|
|
||||||
const ShapeIndex& shape_index = {});
|
|
||||||
|
|
||||||
// Old methods kept for smooth subclassing transition BEGIN.
|
// Old methods kept for smooth subclassing transition BEGIN.
|
||||||
// TODO(b/80131774): Remove this code.
|
// TODO(b/80131774): Remove this code.
|
||||||
|
|
||||||
@ -1504,6 +1482,19 @@ class HloInstruction {
|
|||||||
|
|
||||||
// Delegates to HloSliceInstruction::IsInPlaceSlice.
|
// Delegates to HloSliceInstruction::IsInPlaceSlice.
|
||||||
bool IsInPlaceSlice() const;
|
bool IsInPlaceSlice() const;
|
||||||
|
|
||||||
|
// Returns the literal associated with this instruction.
|
||||||
|
const Literal& literal() const;
|
||||||
|
|
||||||
|
// Returns whether the instruction is a constant.
|
||||||
|
bool IsConstant() const;
|
||||||
|
|
||||||
|
// Delegate to HloConstantInstruction::RelayoutConstant.
|
||||||
|
void RelayoutConstant(const Layout& new_layout,
|
||||||
|
const ShapeIndex& shape_index = {});
|
||||||
|
|
||||||
|
// Delegates to HloTraceInstruction::TracingTag.
|
||||||
|
string TracingTag() const;
|
||||||
// Old methods kept for smooth subclassing transition END.
|
// Old methods kept for smooth subclassing transition END.
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
@ -1544,7 +1535,7 @@ class HloInstruction {
|
|||||||
CanonicalNameMap* canonical_name_map) const;
|
CanonicalNameMap* canonical_name_map) const;
|
||||||
|
|
||||||
// Prints an operand to a string.
|
// Prints an operand to a string.
|
||||||
string OperandsToStringWithCanonicalNameMap(
|
virtual string OperandsToStringWithCanonicalNameMap(
|
||||||
const HloPrintOptions& options,
|
const HloPrintOptions& options,
|
||||||
CanonicalNameMap* canonical_name_map) const;
|
CanonicalNameMap* canonical_name_map) const;
|
||||||
|
|
||||||
@ -1639,9 +1630,6 @@ class HloInstruction {
|
|||||||
// Result shape of this instruction.
|
// Result shape of this instruction.
|
||||||
Shape shape_;
|
Shape shape_;
|
||||||
|
|
||||||
// Literal, only present for kConstant.
|
|
||||||
std::unique_ptr<Literal> literal_;
|
|
||||||
|
|
||||||
// Constant index, only present for kGetTupleElement.
|
// Constant index, only present for kGetTupleElement.
|
||||||
int64 tuple_index_ = -1;
|
int64 tuple_index_ = -1;
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
using ::tensorflow::str_util::Join;
|
using ::tensorflow::str_util::Join;
|
||||||
|
using ::tensorflow::strings::StrAppend;
|
||||||
using ::tensorflow::strings::StrCat;
|
using ::tensorflow::strings::StrCat;
|
||||||
|
|
||||||
HloBatchNormInstruction::HloBatchNormInstruction(
|
HloBatchNormInstruction::HloBatchNormInstruction(
|
||||||
@ -586,4 +587,105 @@ std::unique_ptr<HloInstruction> HloSliceInstruction::CloneWithNewOperandsImpl(
|
|||||||
return MakeUnique<HloSliceInstruction>(shape, new_operands[0], slice_starts_,
|
return MakeUnique<HloSliceInstruction>(shape, new_operands[0], slice_starts_,
|
||||||
slice_limits_, slice_strides_);
|
slice_limits_, slice_strides_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloConstantInstruction::HloConstantInstruction(std::unique_ptr<Literal> literal)
|
||||||
|
: HloInstruction(HloOpcode::kConstant, CHECK_NOTNULL(literal)->shape()),
|
||||||
|
literal_(std::move(literal)) {}
|
||||||
|
|
||||||
|
HloInstructionProto HloConstantInstruction::ToProto() const {
|
||||||
|
HloInstructionProto proto = HloInstruction::ToProto();
|
||||||
|
*proto.mutable_literal() = literal_->ToProto();
|
||||||
|
return proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HloConstantInstruction::IsElementwise() const { return true; }
|
||||||
|
|
||||||
|
void HloConstantInstruction::RelayoutConstant(const Layout& new_layout,
|
||||||
|
const ShapeIndex& shape_index) {
|
||||||
|
Shape* mutable_array_subshape =
|
||||||
|
ShapeUtil::GetMutableSubshape(mutable_shape(), shape_index);
|
||||||
|
CHECK(ShapeUtil::IsArray(*mutable_array_subshape));
|
||||||
|
|
||||||
|
// Normally array_subshape will always have a layout, but this invariant is
|
||||||
|
// temporarily broken in LayoutAssignment::AssignLayouts.
|
||||||
|
|
||||||
|
if (!mutable_array_subshape->has_layout() ||
|
||||||
|
!LayoutUtil::Equal(mutable_array_subshape->layout(), new_layout)) {
|
||||||
|
literal_ = literal_->Relayout(new_layout, shape_index);
|
||||||
|
*mutable_array_subshape->mutable_layout() = new_layout;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HloConstantInstruction::IdenticalSlowPath(
|
||||||
|
const HloInstruction& other,
|
||||||
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
|
eq_computations) const {
|
||||||
|
const auto& other_slice = static_cast<const HloSliceInstruction&>(other);
|
||||||
|
return literal() == other_slice.literal();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<HloInstruction>
|
||||||
|
HloConstantInstruction::CloneWithNewOperandsImpl(
|
||||||
|
const Shape& shape,
|
||||||
|
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
|
||||||
|
HloCloneContext* context) const {
|
||||||
|
return MakeUnique<HloConstantInstruction>(literal_->CloneToUnique());
|
||||||
|
}
|
||||||
|
|
||||||
|
string HloConstantInstruction::OperandsToStringWithCanonicalNameMap(
|
||||||
|
const HloPrintOptions& options,
|
||||||
|
CanonicalNameMap* canonical_name_map) const {
|
||||||
|
string operands;
|
||||||
|
// For constants, show the actual value in place of an empty operand list.
|
||||||
|
if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) ||
|
||||||
|
options.print_large_constants()) {
|
||||||
|
// Literal::ToString emits multidimensional arrays over multiple
|
||||||
|
// lines. Compact this into one line by stripping out white space.
|
||||||
|
string tmp = literal().ToString();
|
||||||
|
std::replace(tmp.begin(), tmp.end(), '\n', ' ');
|
||||||
|
std::vector<string> v = tensorflow::str_util::Split(tmp, ' ');
|
||||||
|
bool first = true;
|
||||||
|
// Concatenate elements in "v" with spaces separating them, but ignoring
|
||||||
|
// empty entries.
|
||||||
|
for (const auto& s : v) {
|
||||||
|
if (s.empty()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
StrAppend(&operands, (first ? "" : " "), s);
|
||||||
|
first = false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Do not show large constants or tuples.
|
||||||
|
operands = "{...}";
|
||||||
|
}
|
||||||
|
return operands;
|
||||||
|
}
|
||||||
|
|
||||||
|
HloTraceInstruction::HloTraceInstruction(const string& tag,
|
||||||
|
HloInstruction* operand)
|
||||||
|
: HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()),
|
||||||
|
literal_(Literal::CreateR1U8(tag)) {
|
||||||
|
AppendOperand(operand);
|
||||||
|
operand->set_tracing(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
HloInstructionProto HloTraceInstruction::ToProto() const {
|
||||||
|
HloInstructionProto proto = HloInstruction::ToProto();
|
||||||
|
*proto.mutable_literal() = literal_->ToProto();
|
||||||
|
return proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HloTraceInstruction::IdenticalSlowPath(
|
||||||
|
const HloInstruction& other,
|
||||||
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
|
eq_computations) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<HloInstruction> HloTraceInstruction::CloneWithNewOperandsImpl(
|
||||||
|
const Shape& shape,
|
||||||
|
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
|
||||||
|
HloCloneContext* context) const {
|
||||||
|
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode());
|
||||||
|
}
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -433,6 +433,62 @@ class HloSliceInstruction : public HloInstruction {
|
|||||||
// Describes whether the slice can be lowered to an offset into the operand.
|
// Describes whether the slice can be lowered to an offset into the operand.
|
||||||
bool is_in_place_slice_ = false;
|
bool is_in_place_slice_ = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class HloConstantInstruction : public HloInstruction {
|
||||||
|
public:
|
||||||
|
explicit HloConstantInstruction(std::unique_ptr<Literal> literal);
|
||||||
|
// Returns the literal associated with this instruction.
|
||||||
|
const Literal& literal() const { return *literal_; }
|
||||||
|
// Returns a serialized representation of this instruction.
|
||||||
|
HloInstructionProto ToProto() const override;
|
||||||
|
// Returns true if this instruction is elementwise on all its operands.
|
||||||
|
bool IsElementwise() const override;
|
||||||
|
|
||||||
|
// Change the layout for an Constant Hlo instruction to match new_layout. For
|
||||||
|
// tuple shaped constants shape_index is the path to the internal array
|
||||||
|
// subshape whose layout needs to be changed.
|
||||||
|
void RelayoutConstant(const Layout& new_layout,
|
||||||
|
const ShapeIndex& shape_index = {});
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool IdenticalSlowPath(
|
||||||
|
const HloInstruction& other,
|
||||||
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
|
eq_computations) const override;
|
||||||
|
string OperandsToStringWithCanonicalNameMap(
|
||||||
|
const HloPrintOptions& options,
|
||||||
|
CanonicalNameMap* canonical_name_map) const override;
|
||||||
|
// Implementation for non-common logic of CloneWithNewOperands.
|
||||||
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
||||||
|
const Shape& shape,
|
||||||
|
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
|
||||||
|
HloCloneContext* context) const override;
|
||||||
|
// TODO(b/36360764): Remove unique_ptr wrapping.
|
||||||
|
std::unique_ptr<Literal> literal_;
|
||||||
|
};
|
||||||
|
|
||||||
|
class HloTraceInstruction : public HloInstruction {
|
||||||
|
public:
|
||||||
|
explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
|
||||||
|
// Returns a tag to be used in tracing.
|
||||||
|
string TracingTag() const { return literal_->GetR1U8AsString(); }
|
||||||
|
// Returns a serialized representation of this instruction.
|
||||||
|
HloInstructionProto ToProto() const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool IdenticalSlowPath(
|
||||||
|
const HloInstruction& other,
|
||||||
|
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||||
|
eq_computations) const override;
|
||||||
|
// Implementation for non-common logic of CloneWithNewOperands.
|
||||||
|
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
||||||
|
const Shape& shape,
|
||||||
|
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
|
||||||
|
HloCloneContext* context) const override;
|
||||||
|
// TODO(b/36360764): Remove unique_ptr wrapping.
|
||||||
|
std::unique_ptr<Literal> literal_;
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
|
||||||
|
@ -127,9 +127,14 @@ Status HloModuleGroupMetadata::VerifyCompanionSets() const {
|
|||||||
for (HloInstruction* instruction : *companions) {
|
for (HloInstruction* instruction : *companions) {
|
||||||
// Go through all the communicating instructions (send, recv) of the given
|
// Go through all the communicating instructions (send, recv) of the given
|
||||||
// companion, and record their device.
|
// companion, and record their device.
|
||||||
|
auto it = tracked_instructions_comms_.find(instruction);
|
||||||
|
if (it == tracked_instructions_comms_.end()) {
|
||||||
|
// Companions can be added even if they have no communicating
|
||||||
|
// instructions, if they are parent of companions.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
std::unordered_set<int64> comm_devices;
|
std::unordered_set<int64> comm_devices;
|
||||||
for (HloInstruction* comm_instruction :
|
for (HloInstruction* comm_instruction : it->second) {
|
||||||
tracked_instructions_comms_.at(instruction)) {
|
|
||||||
auto device = GetInstructionDevice(*comm_instruction);
|
auto device = GetInstructionDevice(*comm_instruction);
|
||||||
TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString()
|
TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString()
|
||||||
<< " does not have a device";
|
<< " does not have a device";
|
||||||
|
@ -232,7 +232,13 @@ def _dnn_tree_combined_model_fn(features,
|
|||||||
return update_op
|
return update_op
|
||||||
|
|
||||||
if predict_with_tree_only:
|
if predict_with_tree_only:
|
||||||
|
if mode == model_fn.ModeKeys.TRAIN or mode == model_fn.ModeKeys.PREDICT:
|
||||||
tree_train_logits = tree_logits
|
tree_train_logits = tree_logits
|
||||||
|
else:
|
||||||
|
tree_train_logits = control_flow_ops.cond(
|
||||||
|
global_step > dnn_steps_to_train,
|
||||||
|
lambda: tree_logits,
|
||||||
|
lambda: dnn_logits)
|
||||||
else:
|
else:
|
||||||
tree_train_logits = dnn_logits + tree_logits
|
tree_train_logits = dnn_logits + tree_logits
|
||||||
|
|
||||||
|
@ -36,6 +36,7 @@ except ImportError:
|
|||||||
|
|
||||||
|
|
||||||
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
|
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
|
||||||
|
_ENDPOINTS_SEPARATOR = ','
|
||||||
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
|
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
|
||||||
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
|
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
|
||||||
|
|
||||||
@ -69,8 +70,8 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
return _GKE_ENV_VARIABLE in os.environ
|
return _GKE_ENV_VARIABLE in os.environ
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _gkeMaster():
|
def _gkeEndpoints():
|
||||||
return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
|
return os.environ[_GKE_ENV_VARIABLE]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _envVarFallback():
|
def _envVarFallback():
|
||||||
@ -143,7 +144,7 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
# When using GKE with Cloud TPUs, the env variable will be set.
|
# When using GKE with Cloud TPUs, the env variable will be set.
|
||||||
if tpu is None:
|
if tpu is None:
|
||||||
if in_gke:
|
if in_gke:
|
||||||
tpu = self._gkeMaster()
|
tpu = self._gkeEndpoints()
|
||||||
else:
|
else:
|
||||||
tpu = self._envVarFallback()
|
tpu = self._envVarFallback()
|
||||||
|
|
||||||
@ -214,7 +215,7 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
ValueError: If none of the TPUs specified exists.
|
ValueError: If none of the TPUs specified exists.
|
||||||
"""
|
"""
|
||||||
if not self._shouldResolve():
|
if not self._shouldResolve():
|
||||||
return self._tpu
|
return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0]
|
||||||
|
|
||||||
job_tasks = self.cluster_spec().job_tasks(self._job_name)
|
job_tasks = self.cluster_spec().job_tasks(self._job_name)
|
||||||
if not job_tasks:
|
if not job_tasks:
|
||||||
@ -280,8 +281,12 @@ class TPUClusterResolver(ClusterResolver):
|
|||||||
# Case 3.
|
# Case 3.
|
||||||
return None
|
return None
|
||||||
# Case 2.
|
# Case 2.
|
||||||
cluster_spec = {self._job_name: [self._tpu[len(
|
cluster_spec = {
|
||||||
compat.as_bytes('grpc://')):]]}
|
self._job_name: [
|
||||||
|
x[len(compat.as_bytes('grpc://')):]
|
||||||
|
for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
if self._coordinator_address:
|
if self._coordinator_address:
|
||||||
# {1, 2}.a
|
# {1, 2}.a
|
||||||
|
@ -402,13 +402,61 @@ class TPUClusterResolverTest(test.TestCase):
|
|||||||
compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master())
|
compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master())
|
||||||
self.assertEqual(None, tpu_cluster_resolver.cluster_spec())
|
self.assertEqual(None, tpu_cluster_resolver.cluster_spec())
|
||||||
|
|
||||||
def testGkeEnvironment(self):
|
def testGkeEnvironmentForDonut(self):
|
||||||
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
|
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
|
||||||
self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ)
|
|
||||||
|
self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
|
||||||
self.assertTrue(TPUClusterResolver._inGke())
|
self.assertTrue(TPUClusterResolver._inGke())
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
compat.as_bytes('grpc://10.120.27.5:8470'),
|
compat.as_bytes('grpc://10.120.27.5:8470'),
|
||||||
compat.as_bytes(TPUClusterResolver._gkeMaster()))
|
compat.as_bytes(TPUClusterResolver._gkeEndpoints()))
|
||||||
|
|
||||||
|
tpu_cluster_resolver = TPUClusterResolver()
|
||||||
|
self.assertEqual(
|
||||||
|
compat.as_bytes('grpc://10.120.27.5:8470'),
|
||||||
|
compat.as_bytes(tpu_cluster_resolver.master()))
|
||||||
|
actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
|
||||||
|
expected_proto = """
|
||||||
|
job {
|
||||||
|
name: 'worker'
|
||||||
|
tasks { key: 0 value: '10.120.27.5:8470' }
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
||||||
|
|
||||||
|
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
|
||||||
|
|
||||||
|
def testGkeEnvironmentForPod(self):
|
||||||
|
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,'
|
||||||
|
'grpc://10.120.27.6:8470,'
|
||||||
|
'grpc://10.120.27.7:8470,'
|
||||||
|
'grpc://10.120.27.8:8470')
|
||||||
|
|
||||||
|
self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
|
||||||
|
self.assertTrue(TPUClusterResolver._inGke())
|
||||||
|
self.assertEqual(
|
||||||
|
compat.as_bytes('grpc://10.120.27.5:8470,'
|
||||||
|
'grpc://10.120.27.6:8470,'
|
||||||
|
'grpc://10.120.27.7:8470,'
|
||||||
|
'grpc://10.120.27.8:8470'),
|
||||||
|
compat.as_bytes(TPUClusterResolver._gkeEndpoints()))
|
||||||
|
|
||||||
|
tpu_cluster_resolver = TPUClusterResolver()
|
||||||
|
self.assertEqual(
|
||||||
|
compat.as_bytes('grpc://10.120.27.5:8470'),
|
||||||
|
compat.as_bytes(tpu_cluster_resolver.master()))
|
||||||
|
actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
|
||||||
|
expected_proto = """
|
||||||
|
job {
|
||||||
|
name: 'worker'
|
||||||
|
tasks { key: 0 value: '10.120.27.5:8470' }
|
||||||
|
tasks { key: 1 value: '10.120.27.6:8470' }
|
||||||
|
tasks { key: 2 value: '10.120.27.7:8470' }
|
||||||
|
tasks { key: 3 value: '10.120.27.8:8470' }
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
|
||||||
|
|
||||||
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
|
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
|
||||||
|
|
||||||
def testDiscoveryUrl(self):
|
def testDiscoveryUrl(self):
|
||||||
|
@ -18,7 +18,16 @@ cmake_policy(SET CMP0022 NEW)
|
|||||||
|
|
||||||
# Options
|
# Options
|
||||||
option(tensorflow_VERBOSE "Enable for verbose output" OFF)
|
option(tensorflow_VERBOSE "Enable for verbose output" OFF)
|
||||||
|
|
||||||
|
if(WIN32)
|
||||||
|
# BoringSSL is disabled for windows as it currently doesn't build with
|
||||||
|
# MSBuild. (Ninja is required.)
|
||||||
option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF)
|
option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF)
|
||||||
|
else()
|
||||||
|
# BoringSSL is enabled for gRPC.
|
||||||
|
option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" ON)
|
||||||
|
endif()
|
||||||
|
|
||||||
option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON)
|
option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON)
|
||||||
option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF)
|
option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF)
|
||||||
option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF)
|
option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF)
|
||||||
|
17
tensorflow/contrib/cmake/external/grpc.cmake
vendored
17
tensorflow/contrib/cmake/external/grpc.cmake
vendored
@ -20,6 +20,10 @@ set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc)
|
|||||||
set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f)
|
set(GRPC_TAG d184fa229d75d336aedea0041bd59cb93e7e267f)
|
||||||
|
|
||||||
if(WIN32)
|
if(WIN32)
|
||||||
|
# We use unsecure gRPC because boringssl does not build on windows
|
||||||
|
set(grpc_TARGET grpc++_unsecure)
|
||||||
|
set(grpc_DEPENDS protobuf zlib)
|
||||||
|
set(grpc_SSL_PROVIDER NONE)
|
||||||
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
|
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
|
||||||
set(grpc_STATIC_LIBRARIES
|
set(grpc_STATIC_LIBRARIES
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib
|
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib
|
||||||
@ -32,9 +36,12 @@ if(WIN32)
|
|||||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib)
|
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib)
|
||||||
endif()
|
endif()
|
||||||
else()
|
else()
|
||||||
|
set(grpc_TARGET grpc++)
|
||||||
|
set(grpc_DEPENDS boringssl protobuf zlib)
|
||||||
|
set(grpc_SSL_PROVIDER module)
|
||||||
set(grpc_STATIC_LIBRARIES
|
set(grpc_STATIC_LIBRARIES
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a
|
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++.a
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc_unsecure.a
|
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc.a
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a
|
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libaddress_sorting.a
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a
|
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/third_party/cares/cares/lib/libcares.a
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a)
|
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgpr.a)
|
||||||
@ -44,13 +51,13 @@ add_definitions(-DGRPC_ARES=0)
|
|||||||
|
|
||||||
ExternalProject_Add(grpc
|
ExternalProject_Add(grpc
|
||||||
PREFIX grpc
|
PREFIX grpc
|
||||||
DEPENDS protobuf zlib
|
DEPENDS ${grpc_DEPENDS}
|
||||||
GIT_REPOSITORY ${GRPC_URL}
|
GIT_REPOSITORY ${GRPC_URL}
|
||||||
GIT_TAG ${GRPC_TAG}
|
GIT_TAG ${GRPC_TAG}
|
||||||
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
|
||||||
BUILD_IN_SOURCE 1
|
BUILD_IN_SOURCE 1
|
||||||
BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES}
|
BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES}
|
||||||
BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure
|
BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target ${grpc_TARGET}
|
||||||
COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin
|
COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin
|
||||||
INSTALL_COMMAND ""
|
INSTALL_COMMAND ""
|
||||||
CMAKE_CACHE_ARGS
|
CMAKE_CACHE_ARGS
|
||||||
@ -59,7 +66,7 @@ ExternalProject_Add(grpc
|
|||||||
-DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS}
|
-DPROTOBUF_INCLUDE_DIRS:STRING=${PROTOBUF_INCLUDE_DIRS}
|
||||||
-DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES}
|
-DPROTOBUF_LIBRARIES:STRING=${protobuf_STATIC_LIBRARIES}
|
||||||
-DZLIB_ROOT:STRING=${ZLIB_INSTALL}
|
-DZLIB_ROOT:STRING=${ZLIB_INSTALL}
|
||||||
-DgRPC_SSL_PROVIDER:STRING=NONE
|
-DgRPC_SSL_PROVIDER:STRING=${grpc_SSL_PROVIDER}
|
||||||
)
|
)
|
||||||
|
|
||||||
# grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h.
|
# grpc/src/core/ext/census/tracing.c depends on the existence of openssl/rand.h.
|
||||||
|
@ -77,6 +77,7 @@ py_library(
|
|||||||
"//tensorflow/python:device_util",
|
"//tensorflow/python:device_util",
|
||||||
"//tensorflow/python:distribute",
|
"//tensorflow/python:distribute",
|
||||||
"//tensorflow/python:framework_ops",
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:pywrap_tensorflow",
|
"//tensorflow/python:pywrap_tensorflow",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//tensorflow/python:variable_scope",
|
"//tensorflow/python:variable_scope",
|
||||||
@ -590,3 +591,22 @@ cuda_py_test(
|
|||||||
"notsan",
|
"notsan",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "metrics_v1_test",
|
||||||
|
srcs = ["metrics_v1_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":combinations",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
"//tensorflow/contrib/data/python/ops:batching",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:metrics",
|
||||||
|
"//tensorflow/python:variables",
|
||||||
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
|
"//tensorflow/python/eager:test",
|
||||||
|
],
|
||||||
|
tags = [
|
||||||
|
"multi_and_single_gpu",
|
||||||
|
"no_pip",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
438
tensorflow/contrib/distribute/python/metrics_v1_test.py
Normal file
438
tensorflow/contrib/distribute/python/metrics_v1_test.py
Normal file
@ -0,0 +1,438 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for V1 metrics."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
|
from tensorflow.contrib.data.python.ops import batching
|
||||||
|
from tensorflow.contrib.distribute.python import combinations
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import metrics
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
|
||||||
|
|
||||||
|
def _labeled_dataset_fn():
|
||||||
|
# First four batches of x: labels, predictions -> (labels == predictions)
|
||||||
|
# 0: 0, 0 -> True; 1: 1, 1 -> True; 2: 2, 2 -> True; 3: 3, 0 -> False
|
||||||
|
# 4: 4, 1 -> False; 5: 0, 2 -> False; 6: 1, 0 -> False; 7: 2, 1 -> False
|
||||||
|
# 8: 3, 2 -> False; 9: 4, 0 -> False; 10: 0, 1 -> False; 11: 1, 2 -> False
|
||||||
|
# 12: 2, 0 -> False; 13: 3, 1 -> False; 14: 4, 2 -> False; 15: 0, 0 -> True
|
||||||
|
return dataset_ops.Dataset.range(1000).map(
|
||||||
|
lambda x: {"labels": x % 5, "predictions": x % 3}).batch(4)
|
||||||
|
|
||||||
|
|
||||||
|
def _boolean_dataset_fn():
|
||||||
|
# First four batches of labels, predictions: {TP, FP, TN, FN}
|
||||||
|
# with a threshold of 0.5:
|
||||||
|
# T, T -> TP; F, T -> FP; T, F -> FN
|
||||||
|
# F, F -> TN; T, T -> TP; F, T -> FP
|
||||||
|
# T, F -> FN; F, F -> TN; T, T -> TP
|
||||||
|
# F, T -> FP; T, F -> FN; F, F -> TN
|
||||||
|
return dataset_ops.Dataset.from_tensor_slices({
|
||||||
|
"labels": [True, False, True, False],
|
||||||
|
"predictions": [True, True, False, False]}).repeat().batch(3)
|
||||||
|
|
||||||
|
|
||||||
|
def _threshold_dataset_fn():
|
||||||
|
# First four batches of labels, predictions: {TP, FP, TN, FN}
|
||||||
|
# with a threshold of 0.5:
|
||||||
|
# True, 1.0 -> TP; False, .75 -> FP; True, .25 -> FN
|
||||||
|
# False, 0.0 -> TN; True, 1.0 -> TP; False, .75 -> FP
|
||||||
|
# True, .25 -> FN; False, 0.0 -> TN; True, 1.0 -> TP
|
||||||
|
# False, .75 -> FP; True, .25 -> FN; False, 0.0 -> TN
|
||||||
|
return dataset_ops.Dataset.from_tensor_slices({
|
||||||
|
"labels": [True, False, True, False],
|
||||||
|
"predictions": [1.0, 0.75, 0.25, 0.]}).repeat().batch(3)
|
||||||
|
|
||||||
|
|
||||||
|
def _regression_dataset_fn():
|
||||||
|
return dataset_ops.Dataset.from_tensor_slices({
|
||||||
|
"labels": [1., .5, 1., 0.],
|
||||||
|
"predictions": [1., .75, .25, 0.]}).repeat()
|
||||||
|
|
||||||
|
|
||||||
|
def all_combinations():
|
||||||
|
return combinations.combine(
|
||||||
|
distribution=[combinations.default_strategy,
|
||||||
|
combinations.one_device_strategy,
|
||||||
|
combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
combinations.mirrored_strategy_with_two_gpus],
|
||||||
|
mode=["graph"])
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(josh11b): Test metrics.recall_at_top_k, metrics.average_precision_at_k,
|
||||||
|
# metrics.precision_at_k
|
||||||
|
class MetricsV1Test(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
|
||||||
|
with ops.Graph().as_default(), distribution.scope():
|
||||||
|
iterator = distribution.distribute_dataset(
|
||||||
|
dataset_fn).make_one_shot_iterator()
|
||||||
|
value, update = distribution.call_for_each_tower(
|
||||||
|
metric_fn, iterator.get_next())
|
||||||
|
update = distribution.group(update)
|
||||||
|
self.evaluate(variables.local_variables_initializer())
|
||||||
|
# TODO(josh11b): Once we switch to using a global batch size for input,
|
||||||
|
# replace "distribution.num_towers" with "1".
|
||||||
|
batches_per_update = distribution.num_towers
|
||||||
|
|
||||||
|
# Update variables using the first `num_towers` batches.
|
||||||
|
self.evaluate(update)
|
||||||
|
self.assertAllClose(expected_fn(batches_per_update), self.evaluate(value),
|
||||||
|
0.001, msg="After first update")
|
||||||
|
|
||||||
|
# Update variables using the second `num_towers` batches.
|
||||||
|
self.evaluate(update)
|
||||||
|
self.assertAllClose(expected_fn(2 * batches_per_update),
|
||||||
|
self.evaluate(value),
|
||||||
|
0.001,
|
||||||
|
msg="After second update")
|
||||||
|
|
||||||
|
if batches_per_update == 1: # Consume 4 input batches
|
||||||
|
self.evaluate(update)
|
||||||
|
self.assertAllClose(expected_fn(3 * batches_per_update),
|
||||||
|
self.evaluate(value),
|
||||||
|
0.001,
|
||||||
|
msg="After third update")
|
||||||
|
self.evaluate(update)
|
||||||
|
self.assertAllClose(expected_fn(4 * batches_per_update),
|
||||||
|
self.evaluate(value),
|
||||||
|
0.001,
|
||||||
|
msg="After fourth update")
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testMean(self, distribution):
|
||||||
|
def _dataset_fn():
|
||||||
|
return dataset_ops.Dataset.range(1000).map(math_ops.to_float).batch(4)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
# Mean(0..3) = 1.5, Mean(0..7) = 3.5, Mean(0..11) = 5.5, etc.
|
||||||
|
return num_batches * 2 - 0.5
|
||||||
|
|
||||||
|
self._test_metric(distribution, _dataset_fn, metrics.mean, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testAccuracy(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.accuracy(labels, predictions)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [3./4, 3./8, 3./12, 4./16][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testMeanPerClassAccuracy(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.mean_per_class_accuracy(
|
||||||
|
labels, predictions, num_classes=5)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
mean = lambda x: sum(x) / len(x)
|
||||||
|
return [mean([1., 1., 1., 0., 0.]),
|
||||||
|
mean([0.5, 0.5, 0.5, 0., 0.]),
|
||||||
|
mean([1./3, 1./3, 0.5, 0., 0.]),
|
||||||
|
mean([0.5, 1./3, 1./3, 0., 0.])][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testMeanIOU(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.mean_iou(
|
||||||
|
labels, predictions, num_classes=5)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
mean = lambda x: sum(x) / len(x)
|
||||||
|
return [mean([1./2, 1./1, 1./1, 0.]), # no class 4 in first batch
|
||||||
|
mean([1./4, 1./4, 1./3, 0., 0.]),
|
||||||
|
mean([1./6, 1./6, 1./5, 0., 0.]),
|
||||||
|
mean([2./8, 1./7, 1./7, 0., 0.])][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _labeled_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testMeanTensor(self, distribution):
|
||||||
|
def _dataset_fn():
|
||||||
|
dataset = dataset_ops.Dataset.range(1000).map(math_ops.to_float)
|
||||||
|
# Want to produce a fixed, known shape, so drop remainder when batching.
|
||||||
|
dataset = dataset.apply(batching.batch_and_drop_remainder(4))
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
# Mean(0, 4, ..., 4 * num_batches - 4) == 2 * num_batches - 2
|
||||||
|
# Mean(1, 5, ..., 4 * num_batches - 3) == 2 * num_batches - 1
|
||||||
|
# Mean(2, 6, ..., 4 * num_batches - 2) == 2 * num_batches
|
||||||
|
# Mean(3, 7, ..., 4 * num_batches - 1) == 2 * num_batches + 1
|
||||||
|
first = 2. * num_batches - 2.
|
||||||
|
return [first, first + 1., first + 2., first + 3.]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _dataset_fn, metrics.mean_tensor, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testAUCROC(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.auc(labels, predictions, num_thresholds=8, curve="ROC",
|
||||||
|
summation_method="careful_interpolation")
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [0.5, 7./9, 0.8, 0.75][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testAUCPR(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.auc(labels, predictions, num_thresholds=8, curve="PR",
|
||||||
|
summation_method="careful_interpolation")
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [0.797267, 0.851238, 0.865411, 0.797267][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testFalseNegatives(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.false_negatives(labels, predictions)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [1., 1., 2., 3.][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testFalseNegativesAtThresholds(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.false_negatives_at_thresholds(labels, predictions, [.5])
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [[1.], [1.], [2.], [3.]][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testTrueNegatives(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.true_negatives(labels, predictions)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [0., 1., 2., 3.][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testTrueNegativesAtThresholds(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.true_negatives_at_thresholds(labels, predictions, [.5])
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [[0.], [1.], [2.], [3.]][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testFalsePositives(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.false_positives(labels, predictions)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [1., 2., 2., 3.][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testFalsePositivesAtThresholds(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.false_positives_at_thresholds(labels, predictions, [.5])
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [[1.], [2.], [2.], [3.]][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testTruePositives(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.true_positives(labels, predictions)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [1., 2., 3., 3.][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testTruePositivesAtThresholds(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.true_positives_at_thresholds(labels, predictions, [.5])
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [[1.], [2.], [3.], [3.]][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testPrecision(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.precision(labels, predictions)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [0.5, 0.5, 0.6, 0.5][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testPrecisionAtThreshold(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.precision_at_thresholds(labels, predictions, [0.5])
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [[0.5], [0.5], [0.6], [0.5]][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testRecall(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.recall(labels, predictions)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _boolean_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testRecallAtThreshold(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.recall_at_thresholds(labels, predictions, [0.5])
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [[0.5], [2./3], [0.6], [0.5]][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testMeanSquaredError(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.mean_squared_error(labels, predictions)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [0., 1./32, 0.208333, 0.15625][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testRootMeanSquaredError(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.root_mean_squared_error(labels, predictions)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [0., 0.176777, 0.456435, 0.395285][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _regression_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testSensitivityAtSpecificity(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.sensitivity_at_specificity(labels, predictions, 0.8)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [0.5, 2./3, 0.6, 0.5][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
@combinations.generate(all_combinations())
|
||||||
|
def testSpecificityAtSensitivity(self, distribution):
|
||||||
|
def _metric_fn(x):
|
||||||
|
labels = x["labels"]
|
||||||
|
predictions = x["predictions"]
|
||||||
|
return metrics.specificity_at_sensitivity(labels, predictions, 0.95)
|
||||||
|
|
||||||
|
def _expected_fn(num_batches):
|
||||||
|
return [0., 1./3, 0.5, 0.5][num_batches - 1]
|
||||||
|
|
||||||
|
self._test_metric(
|
||||||
|
distribution, _threshold_dataset_fn, _metric_fn, _expected_fn)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -31,6 +31,7 @@ from tensorflow.python.eager import tape
|
|||||||
from tensorflow.python.framework import device as tf_device
|
from tensorflow.python.framework import device as tf_device
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.training import coordinator
|
from tensorflow.python.training import coordinator
|
||||||
from tensorflow.python.training import device_util
|
from tensorflow.python.training import device_util
|
||||||
@ -343,6 +344,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
|
|||||||
**values.select_device_mirrored(d, kwargs))
|
**values.select_device_mirrored(d, kwargs))
|
||||||
return values.regroup(updates, values.Mirrored)
|
return values.regroup(updates, values.Mirrored)
|
||||||
|
|
||||||
|
def read_var(self, tower_local_var):
|
||||||
|
"""Read the aggregate value of a tower-local variable."""
|
||||||
|
if isinstance(tower_local_var, values.TowerLocalVariable):
|
||||||
|
return math_ops.add_n(self.unwrap(tower_local_var))
|
||||||
|
assert isinstance(tower_local_var, values.Mirrored)
|
||||||
|
return array_ops.identity(tower_local_var.get())
|
||||||
|
|
||||||
def _fetch(self, val, destination, fn):
|
def _fetch(self, val, destination, fn):
|
||||||
"""Return a copy of `val` or `fn(val)` on `destination`."""
|
"""Return a copy of `val` or `fn(val)` on `destination`."""
|
||||||
if isinstance(val, values.TowerLocalVariable):
|
if isinstance(val, values.TowerLocalVariable):
|
||||||
|
@ -102,6 +102,10 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
|
|||||||
with ops.device(self._device), distribute_lib.UpdateContext(self._device):
|
with ops.device(self._device), distribute_lib.UpdateContext(self._device):
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
def read_var(self, tower_local_var):
|
||||||
|
"""Read the aggregate value of a tower-local variable."""
|
||||||
|
return array_ops.identity(tower_local_var)
|
||||||
|
|
||||||
def _fetch(self, val, destination, fn):
|
def _fetch(self, val, destination, fn):
|
||||||
"""Return a copy of `val` or `fn(val)` on `destination`."""
|
"""Return a copy of `val` or `fn(val)` on `destination`."""
|
||||||
with ops.device(self._device):
|
with ops.device(self._device):
|
||||||
|
@ -29,7 +29,9 @@ from tensorflow.contrib.distributions.python.ops import mvn_diag
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops.distributions import categorical
|
from tensorflow.python.ops.distributions import categorical
|
||||||
from tensorflow.python.ops.distributions import normal
|
from tensorflow.python.ops.distributions import normal
|
||||||
from tensorflow.python.ops.linalg import linear_operator_diag
|
from tensorflow.python.ops.linalg import linear_operator_diag
|
||||||
@ -540,5 +542,51 @@ class PadDynamicTest(_PadTest, test.TestCase):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TestMoveDimension(test.TestCase):
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
def test_move_dimension_static_shape(self):
|
||||||
|
|
||||||
|
x = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
|
||||||
|
|
||||||
|
x_perm = distribution_util.move_dimension(x, 1, 1)
|
||||||
|
self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 4, 1, 6])
|
||||||
|
|
||||||
|
x_perm = distribution_util.move_dimension(x, 0, 3)
|
||||||
|
self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6])
|
||||||
|
|
||||||
|
x_perm = distribution_util.move_dimension(x, 0, -2)
|
||||||
|
self.assertAllEqual(x_perm.shape.as_list(), [30, 4, 1, 200, 6])
|
||||||
|
|
||||||
|
x_perm = distribution_util.move_dimension(x, 4, 2)
|
||||||
|
self.assertAllEqual(x_perm.shape.as_list(), [200, 30, 6, 4, 1])
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes()
|
||||||
|
def test_move_dimension_dynamic_shape(self):
|
||||||
|
|
||||||
|
x_ = random_ops.random_normal(shape=[200, 30, 4, 1, 6])
|
||||||
|
x = array_ops.placeholder_with_default(input=x_, shape=None)
|
||||||
|
|
||||||
|
x_perm = distribution_util.move_dimension(x, 1, 1)
|
||||||
|
self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
|
||||||
|
[200, 30, 4, 1, 6])
|
||||||
|
|
||||||
|
x_perm = distribution_util.move_dimension(x, 0, 3)
|
||||||
|
self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
|
||||||
|
[30, 4, 1, 200, 6])
|
||||||
|
|
||||||
|
x_perm = distribution_util.move_dimension(x, 0, -2)
|
||||||
|
self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
|
||||||
|
[30, 4, 1, 200, 6])
|
||||||
|
|
||||||
|
x_perm = distribution_util.move_dimension(x, 4, 2)
|
||||||
|
self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
|
||||||
|
[200, 30, 6, 4, 1])
|
||||||
|
|
||||||
|
x_perm = distribution_util.move_dimension(x, -1, 2)
|
||||||
|
self.assertAllEqual(self.evaluate(array_ops.shape(x_perm)),
|
||||||
|
[200, 30, 6, 4, 1])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -21,12 +21,19 @@ from __future__ import print_function
|
|||||||
from tensorflow.contrib import linalg
|
from tensorflow.contrib import linalg
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import smart_cond
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops.distributions import distribution as distribution_lib
|
from tensorflow.python.ops.distributions import distribution as distribution_lib
|
||||||
|
|
||||||
|
# The following two lines are redundant, in a sense. The first enables
|
||||||
|
# good coding practice *within* this file (`util.prefer_static_value`
|
||||||
|
# rather than `prefer_static_value`). The second ensures that users
|
||||||
|
# also get the core utils when they import this file.
|
||||||
|
from tensorflow.python.ops.distributions import util
|
||||||
from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import
|
from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import
|
||||||
|
|
||||||
|
|
||||||
@ -484,3 +491,75 @@ def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution,
|
|||||||
def static_value(x):
|
def static_value(x):
|
||||||
"""Returns the static value of a `Tensor` or `None`."""
|
"""Returns the static value of a `Tensor` or `None`."""
|
||||||
return tensor_util.constant_value(ops.convert_to_tensor(x))
|
return tensor_util.constant_value(ops.convert_to_tensor(x))
|
||||||
|
|
||||||
|
|
||||||
|
def move_dimension(x, source_idx, dest_idx):
|
||||||
|
"""Move a single tensor dimension within its shape.
|
||||||
|
|
||||||
|
This is a special case of `tf.transpose()`, which applies
|
||||||
|
arbitrary permutations to tensor dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: Tensor of rank `ndims`.
|
||||||
|
source_idx: Integer index into `x.shape` (negative indexing is
|
||||||
|
supported).
|
||||||
|
dest_idx: Integer index into `x.shape` (negative indexing is
|
||||||
|
supported).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
x_perm: Tensor of rank `ndims`, in which the dimension at original
|
||||||
|
index `source_idx` has been moved to new index `dest_idx`, with
|
||||||
|
all other dimensions retained in their original order.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
x = tf.placeholder(shape=[200, 30, 4, 1, 6])
|
||||||
|
x_perm = _move_dimension(x, 1, 1) # no-op
|
||||||
|
x_perm = _move_dimension(x, 0, 3) # result shape [30, 4, 1, 200, 6]
|
||||||
|
x_perm = _move_dimension(x, 0, -2) # equivalent to previous
|
||||||
|
x_perm = _move_dimension(x, 4, 2) # result shape [200, 30, 6, 4, 1]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
ndims = util.prefer_static_rank(x)
|
||||||
|
if isinstance(source_idx, int):
|
||||||
|
dtype = dtypes.int32
|
||||||
|
else:
|
||||||
|
dtype = dtypes.as_dtype(source_idx.dtype)
|
||||||
|
|
||||||
|
# Handle negative indexing. Since ndims might be dynamic, this makes
|
||||||
|
# source_idx and dest_idx also possibly dynamic.
|
||||||
|
if source_idx < 0:
|
||||||
|
source_idx = ndims + source_idx
|
||||||
|
if dest_idx < 0:
|
||||||
|
dest_idx = ndims + dest_idx
|
||||||
|
|
||||||
|
# Construct the appropriate permutation of dimensions, depending
|
||||||
|
# whether the source is before or after the destination.
|
||||||
|
def move_left_permutation():
|
||||||
|
return util.prefer_static_value(
|
||||||
|
array_ops.concat([
|
||||||
|
math_ops.range(0, dest_idx, dtype=dtype),
|
||||||
|
[source_idx],
|
||||||
|
math_ops.range(dest_idx, source_idx, dtype=dtype),
|
||||||
|
math_ops.range(source_idx+1, ndims, dtype=dtype)], axis=0))
|
||||||
|
|
||||||
|
def move_right_permutation():
|
||||||
|
return util.prefer_static_value(
|
||||||
|
array_ops.concat([
|
||||||
|
math_ops.range(0, source_idx, dtype=dtype),
|
||||||
|
math_ops.range(source_idx+1, dest_idx+1, dtype=dtype),
|
||||||
|
[source_idx],
|
||||||
|
math_ops.range(dest_idx+1, ndims, dtype=dtype)], axis=0))
|
||||||
|
|
||||||
|
def x_permuted():
|
||||||
|
return array_ops.transpose(
|
||||||
|
x, perm=smart_cond.smart_cond(source_idx < dest_idx,
|
||||||
|
move_right_permutation,
|
||||||
|
move_left_permutation))
|
||||||
|
|
||||||
|
# One final conditional to handle the special case where source
|
||||||
|
# and destination indices are equal.
|
||||||
|
return smart_cond.smart_cond(math_ops.equal(source_idx, dest_idx),
|
||||||
|
lambda: x,
|
||||||
|
x_permuted)
|
||||||
|
@ -529,6 +529,7 @@ def multi_label_head(n_classes,
|
|||||||
applications, the shape is `[batch_size, n_classes]`.
|
applications, the shape is `[batch_size, n_classes]`.
|
||||||
|
|
||||||
Labels can be:
|
Labels can be:
|
||||||
|
|
||||||
* A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
|
* A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
|
||||||
* An integer `SparseTensor` of class indices. The `dense_shape` must be
|
* An integer `SparseTensor` of class indices. The `dense_shape` must be
|
||||||
`[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
|
`[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
|
||||||
|
@ -28,7 +28,6 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import functional_ops
|
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import tensor_array_ops
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
|
|
||||||
@ -279,13 +278,27 @@ def _assert_increasing(t):
|
|||||||
return ops.control_dependencies([assert_increasing])
|
return ops.control_dependencies([assert_increasing])
|
||||||
|
|
||||||
|
|
||||||
def _check_input_types(t, y0):
|
def _check_input_types(y0, t, dt=None):
|
||||||
if not (y0.dtype.is_floating or y0.dtype.is_complex):
|
if not (y0.dtype.is_floating or y0.dtype.is_complex):
|
||||||
raise TypeError('`y0` must have a floating point or complex floating '
|
raise TypeError('`y0` must have a floating point or complex floating '
|
||||||
'point dtype')
|
'point dtype')
|
||||||
if not t.dtype.is_floating:
|
if not t.dtype.is_floating:
|
||||||
raise TypeError('`t` must have a floating point dtype')
|
raise TypeError('`t` must have a floating point dtype')
|
||||||
|
|
||||||
|
if dt is not None and not dt.dtype.is_floating:
|
||||||
|
raise TypeError('`dt` must have a floating point dtype')
|
||||||
|
|
||||||
|
|
||||||
|
def _check_input_sizes(t, dt):
|
||||||
|
if len(t.get_shape().as_list()) > 1:
|
||||||
|
raise ValueError('t must be a 1D tensor')
|
||||||
|
|
||||||
|
if len(dt.get_shape().as_list()) > 1:
|
||||||
|
raise ValueError('t must be a 1D tensor')
|
||||||
|
|
||||||
|
if t.get_shape()[0] != dt.get_shape()[0] + 1:
|
||||||
|
raise ValueError('t and dt have incompatible lengths, must be N and N-1')
|
||||||
|
|
||||||
|
|
||||||
def _dopri5(func,
|
def _dopri5(func,
|
||||||
y0,
|
y0,
|
||||||
@ -510,7 +523,7 @@ def odeint(func,
|
|||||||
# avoiding the need to pack/unpack in user functions.
|
# avoiding the need to pack/unpack in user functions.
|
||||||
y0 = ops.convert_to_tensor(y0, name='y0')
|
y0 = ops.convert_to_tensor(y0, name='y0')
|
||||||
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
|
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
|
||||||
_check_input_types(t, y0)
|
_check_input_types(y0, t)
|
||||||
|
|
||||||
error_dtype = abs(y0).dtype
|
error_dtype = abs(y0).dtype
|
||||||
rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
|
rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol')
|
||||||
@ -530,24 +543,74 @@ def odeint(func,
|
|||||||
class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
|
class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
|
||||||
"""Base class for fixed-grid ODE integrators."""
|
"""Base class for fixed-grid ODE integrators."""
|
||||||
|
|
||||||
def integrate(self, evol_func, y0, time_grid):
|
def integrate(self, evol_func, y0, time_grid, dt_grid, steps_on_intervals):
|
||||||
time_delta_grid = time_grid[1:] - time_grid[:-1]
|
"""Returns integrated values of differential equation on the `time grid`.
|
||||||
|
|
||||||
scan_func = self._make_scan_func(evol_func)
|
Numerically integrates differential equation defined via time derivative
|
||||||
|
evaluator `evol_func` using fixed time steps specified in dt_grid.
|
||||||
|
|
||||||
y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid),
|
Args:
|
||||||
y0)
|
evol_func: Callable, evaluates time derivative of y at a given time.
|
||||||
return array_ops.concat([[y0], y_grid], axis=0)
|
y0: N-D Tensor holds initial values of the solution.
|
||||||
|
time_grid: 1-D Tensor holding the time points at which the solution
|
||||||
|
will be recorded, must have a floating dtype.
|
||||||
|
dt_grid: 1-D Tensor holds fixed time steps to be used on time_grid
|
||||||
|
intervals. Must be a floating dtype and have one less element than that
|
||||||
|
of the time_grid.
|
||||||
|
steps_on_intervals: 1-D Tensor of integer dtype, must have the same size
|
||||||
|
as dt_grid. Specifies number of steps needed for every interval. Assumes
|
||||||
|
steps_on_intervals * dt_grid == time intervals.
|
||||||
|
|
||||||
def _make_scan_func(self, evol_func):
|
Returns:
|
||||||
|
(N+1)-D tensor, where the first dimension corresponds to different
|
||||||
|
time points. Contains the solved value of y for each desired time point in
|
||||||
|
`t`, with the initial value `y0` being the first element along the first
|
||||||
|
dimension.
|
||||||
|
"""
|
||||||
|
|
||||||
def scan_func(y, t_and_dt):
|
iteration_func = self._make_iteration_func(evol_func, dt_grid)
|
||||||
t, dt = t_and_dt
|
integrate_interval = self._make_interval_integrator(iteration_func,
|
||||||
|
steps_on_intervals)
|
||||||
|
|
||||||
|
num_times = array_ops.size(time_grid)
|
||||||
|
current_time = time_grid[0]
|
||||||
|
solution_array = tensor_array_ops.TensorArray(y0.dtype, num_times)
|
||||||
|
solution_array = solution_array.write(0, y0)
|
||||||
|
|
||||||
|
solution_array, _, _, _ = control_flow_ops.while_loop(
|
||||||
|
lambda _, __, ___, i: i < num_times,
|
||||||
|
integrate_interval,
|
||||||
|
(solution_array, y0, current_time, 1)
|
||||||
|
)
|
||||||
|
solution_array = solution_array.stack()
|
||||||
|
solution_array.set_shape(time_grid.get_shape().concatenate(y0.get_shape()))
|
||||||
|
return solution_array
|
||||||
|
|
||||||
|
def _make_iteration_func(self, evol_func, dt_grid):
|
||||||
|
"""Returns a function that builds operations of a single time step."""
|
||||||
|
|
||||||
|
def iteration_func(y, t, dt_step, interval_step):
|
||||||
|
"""Performs a single time step advance."""
|
||||||
|
dt = dt_grid[interval_step - 1]
|
||||||
dy = self._step_func(evol_func, t, dt, y)
|
dy = self._step_func(evol_func, t, dt, y)
|
||||||
dy = math_ops.cast(dy, dtype=y.dtype)
|
dy = math_ops.cast(dy, dtype=y.dtype)
|
||||||
return y + dy
|
return y + dy, t + dt, dt_step + 1, interval_step
|
||||||
|
|
||||||
return scan_func
|
return iteration_func
|
||||||
|
|
||||||
|
def _make_interval_integrator(self, iteration_func, interval_sizes):
|
||||||
|
"""Returns a function that builds operations for interval integration."""
|
||||||
|
|
||||||
|
def integrate_interval(solution_array, y, t, interval_num):
|
||||||
|
"""Integrates y with fixed time step on interval `interval_num`."""
|
||||||
|
y, t, _, _ = control_flow_ops.while_loop(
|
||||||
|
lambda _, __, j, interval_num: j < interval_sizes[interval_num - 1],
|
||||||
|
iteration_func,
|
||||||
|
(y, t, 0, interval_num)
|
||||||
|
)
|
||||||
|
return solution_array.write(interval_num, y), y, t, interval_num + 1
|
||||||
|
|
||||||
|
return integrate_interval
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def _step_func(self, evol_func, t, dt, y):
|
def _step_func(self, evol_func, t, dt, y):
|
||||||
@ -555,6 +618,7 @@ class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)):
|
|||||||
|
|
||||||
|
|
||||||
class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
|
class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
|
||||||
|
"""Fixed grid integrator implementing midpoint scheme."""
|
||||||
|
|
||||||
def _step_func(self, evol_func, t, dt, y):
|
def _step_func(self, evol_func, t, dt, y):
|
||||||
dt_cast = math_ops.cast(dt, y.dtype)
|
dt_cast = math_ops.cast(dt, y.dtype)
|
||||||
@ -563,6 +627,7 @@ class _MidpointFixedGridIntegrator(_FixedGridIntegrator):
|
|||||||
|
|
||||||
|
|
||||||
class _RK4FixedGridIntegrator(_FixedGridIntegrator):
|
class _RK4FixedGridIntegrator(_FixedGridIntegrator):
|
||||||
|
"""Fixed grid integrator implementing RK4 scheme."""
|
||||||
|
|
||||||
def _step_func(self, evol_func, t, dt, y):
|
def _step_func(self, evol_func, t, dt, y):
|
||||||
k1 = evol_func(y, t)
|
k1 = evol_func(y, t)
|
||||||
@ -575,7 +640,7 @@ class _RK4FixedGridIntegrator(_FixedGridIntegrator):
|
|||||||
return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6)
|
return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6)
|
||||||
|
|
||||||
|
|
||||||
def odeint_fixed(func, y0, t, method='rk4', name=None):
|
def odeint_fixed(func, y0, t, dt=None, method='rk4', name=None):
|
||||||
"""ODE integration on a fixed grid (with no step size control).
|
"""ODE integration on a fixed grid (with no step size control).
|
||||||
|
|
||||||
Useful in certain scenarios to avoid the overhead of adaptive step size
|
Useful in certain scenarios to avoid the overhead of adaptive step size
|
||||||
@ -590,6 +655,14 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
|
|||||||
`y`. The initial time point should be the first element of this sequence,
|
`y`. The initial time point should be the first element of this sequence,
|
||||||
and each time must be larger than the previous time. May have any floating
|
and each time must be larger than the previous time. May have any floating
|
||||||
point dtype.
|
point dtype.
|
||||||
|
dt: 0-D or 1-D Tensor providing time step suggestion to be used on time
|
||||||
|
integration intervals in `t`. 1-D Tensor should provide values
|
||||||
|
for all intervals, must have 1 less element than that of `t`.
|
||||||
|
If given a 0-D Tensor, the value is interpreted as time step suggestion
|
||||||
|
same for all intervals. If passed None, then time step is set to be the
|
||||||
|
t[1:] - t[:-1]. Defaults to None. The actual step size is obtained by
|
||||||
|
insuring an integer number of steps per interval, potentially reducing the
|
||||||
|
time step.
|
||||||
method: One of 'midpoint' or 'rk4'.
|
method: One of 'midpoint' or 'rk4'.
|
||||||
name: Optional name for the resulting operation.
|
name: Optional name for the resulting operation.
|
||||||
|
|
||||||
@ -602,16 +675,29 @@ def odeint_fixed(func, y0, t, method='rk4', name=None):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: Upon caller errors.
|
ValueError: Upon caller errors.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, 'odeint_fixed', [y0, t]):
|
with ops.name_scope(name, 'odeint_fixed', [y0, t, dt]):
|
||||||
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
|
t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t')
|
||||||
y0 = ops.convert_to_tensor(y0, name='y0')
|
y0 = ops.convert_to_tensor(y0, name='y0')
|
||||||
_check_input_types(t, y0)
|
|
||||||
|
intervals = t[1:] - t[:-1]
|
||||||
|
if dt is None:
|
||||||
|
dt = intervals
|
||||||
|
dt = ops.convert_to_tensor(dt, preferred_dtype=dtypes.float64, name='dt')
|
||||||
|
|
||||||
|
steps_on_intervals = math_ops.ceil(intervals / dt)
|
||||||
|
dt = intervals / steps_on_intervals
|
||||||
|
steps_on_intervals = math_ops.cast(steps_on_intervals, dtype=dtypes.int32)
|
||||||
|
|
||||||
|
_check_input_types(y0, t, dt)
|
||||||
|
_check_input_sizes(t, dt)
|
||||||
|
|
||||||
with _assert_increasing(t):
|
with _assert_increasing(t):
|
||||||
with ops.name_scope(method):
|
with ops.name_scope(method):
|
||||||
if method == 'midpoint':
|
if method == 'midpoint':
|
||||||
return _MidpointFixedGridIntegrator().integrate(func, y0, t)
|
return _MidpointFixedGridIntegrator().integrate(func, y0, t, dt,
|
||||||
|
steps_on_intervals)
|
||||||
elif method == 'rk4':
|
elif method == 'rk4':
|
||||||
return _RK4FixedGridIntegrator().integrate(func, y0, t)
|
return _RK4FixedGridIntegrator().integrate(func, y0, t, dt,
|
||||||
|
steps_on_intervals)
|
||||||
else:
|
else:
|
||||||
raise ValueError('method not supported: {!s}'.format(method))
|
raise ValueError('method not supported: {!s}'.format(method))
|
||||||
|
@ -242,40 +242,56 @@ class InterpolationTest(test.TestCase):
|
|||||||
|
|
||||||
class OdeIntFixedTest(test.TestCase):
|
class OdeIntFixedTest(test.TestCase):
|
||||||
|
|
||||||
def _test_integrate_sine(self, method):
|
def _test_integrate_sine(self, method, t, dt=None):
|
||||||
|
|
||||||
def evol_func(y, t):
|
def evol_func(y, t):
|
||||||
del t
|
del t
|
||||||
return array_ops.stack([y[1], -y[0]])
|
return array_ops.stack([y[1], -y[0]])
|
||||||
|
|
||||||
y0 = [0., 1.]
|
y0 = [0., 1.]
|
||||||
time_grid = np.linspace(0., 10., 200)
|
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
|
||||||
y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
|
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
y_grid_array = sess.run(y_grid)
|
y_grid_array = sess.run(y_grid)
|
||||||
|
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2)
|
y_grid_array[:, 0], np.sin(t), rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
def _test_integrate_gaussian(self, method):
|
def _test_integrate_gaussian(self, method, t, dt=None):
|
||||||
|
|
||||||
def evol_func(y, t):
|
def evol_func(y, t):
|
||||||
return -math_ops.cast(t, dtype=y.dtype) * y[0]
|
return -math_ops.cast(t, dtype=y.dtype) * y[0]
|
||||||
|
|
||||||
y0 = [1.]
|
y0 = [1.]
|
||||||
time_grid = np.linspace(0., 2., 100)
|
y_grid = odes.odeint_fixed(evol_func, y0, t, dt, method=method)
|
||||||
y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method)
|
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
y_grid_array = sess.run(y_grid)
|
y_grid_array = sess.run(y_grid)
|
||||||
|
|
||||||
np.testing.assert_allclose(
|
np.testing.assert_allclose(
|
||||||
y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2)
|
y_grid_array[:, 0], np.exp(-t**2 / 2), rtol=1e-2, atol=1e-2)
|
||||||
|
|
||||||
|
def _test_integrate_sine_all(self, method):
|
||||||
|
uniform_time_grid = np.linspace(0., 10., 200)
|
||||||
|
non_uniform_time_grid = np.asarray([0.0, 0.4, 4.7, 5.2, 7.0])
|
||||||
|
uniform_dt = 0.02
|
||||||
|
non_uniform_dt = np.asarray([0.01, 0.001, 0.05, 0.03])
|
||||||
|
self._test_integrate_sine(method, uniform_time_grid)
|
||||||
|
self._test_integrate_sine(method, non_uniform_time_grid, uniform_dt)
|
||||||
|
self._test_integrate_sine(method, non_uniform_time_grid, non_uniform_dt)
|
||||||
|
|
||||||
|
def _test_integrate_gaussian_all(self, method):
|
||||||
|
uniform_time_grid = np.linspace(0., 2., 100)
|
||||||
|
non_uniform_time_grid = np.asarray([0.0, 0.1, 0.7, 1.2, 2.0])
|
||||||
|
uniform_dt = 0.01
|
||||||
|
non_uniform_dt = np.asarray([0.01, 0.001, 0.1, 0.03])
|
||||||
|
self._test_integrate_gaussian(method, uniform_time_grid)
|
||||||
|
self._test_integrate_gaussian(method, non_uniform_time_grid, uniform_dt)
|
||||||
|
self._test_integrate_gaussian(method, non_uniform_time_grid, non_uniform_dt)
|
||||||
|
|
||||||
def _test_everything(self, method):
|
def _test_everything(self, method):
|
||||||
self._test_integrate_sine(method)
|
self._test_integrate_sine_all(method)
|
||||||
self._test_integrate_gaussian(method)
|
self._test_integrate_gaussian_all(method)
|
||||||
|
|
||||||
def test_midpoint(self):
|
def test_midpoint(self):
|
||||||
self._test_everything('midpoint')
|
self._test_everything('midpoint')
|
||||||
@ -283,6 +299,21 @@ class OdeIntFixedTest(test.TestCase):
|
|||||||
def test_rk4(self):
|
def test_rk4(self):
|
||||||
self._test_everything('rk4')
|
self._test_everything('rk4')
|
||||||
|
|
||||||
|
def test_dt_size_exceptions(self):
|
||||||
|
times = np.linspace(0., 2., 100)
|
||||||
|
dt = np.ones(99) * 0.01
|
||||||
|
dt_wrong_length = np.asarray([0.01, 0.001, 0.1, 0.03])
|
||||||
|
dt_wrong_dim = np.expand_dims(np.linspace(0., 2., 99), axis=0)
|
||||||
|
times_wrong_dim = np.expand_dims(np.linspace(0., 2., 100), axis=0)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self._test_integrate_gaussian('midpoint', times, dt_wrong_length)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self._test_integrate_gaussian('midpoint', times, dt_wrong_dim)
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
self._test_integrate_gaussian('midpoint', times_wrong_dim, dt)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -17,7 +17,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_
|
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_
|
||||||
|
|
||||||
// DO NOT EDIT MANUALLY: This file is automatically generated by
|
// DO NOT EDIT MANUALLY: This file is automatically generated by
|
||||||
// `schema_builtin_ops_header_generator.py`.
|
// `schema/builtin_ops_header/generator.cc`.
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
@ -474,8 +474,9 @@ cc_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
name = "resize_bilinear_float_test",
|
name = "resize_bilinear_test",
|
||||||
srcs = ["resize_bilinear_float_test.cc"],
|
srcs = ["resize_bilinear_test.cc"],
|
||||||
|
tags = ["tflite_not_portable"],
|
||||||
deps = [
|
deps = [
|
||||||
":optimized_base",
|
":optimized_base",
|
||||||
":reference_base",
|
":reference_base",
|
||||||
|
@ -5722,6 +5722,46 @@ inline void ResizeBilinearGeneric(const float* input_data,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline void ResizeBilinearGenericSmallChannel(
|
||||||
|
const T* input_data, const Dims<4>& input_dims, T* output_data,
|
||||||
|
const Dims<4>& output_dims, int32 batches, int32 input_height,
|
||||||
|
int32 input_width, int32 depth, int32 output_height, int32 output_width,
|
||||||
|
float height_scale, float width_scale) {
|
||||||
|
memset(output_data, 0,
|
||||||
|
batches * output_height * output_width * depth * sizeof(T));
|
||||||
|
|
||||||
|
T* output_ptr = &output_data[0];
|
||||||
|
for (int b = 0; b < batches; ++b) {
|
||||||
|
for (int y = 0; y < output_height; ++y) {
|
||||||
|
float input_y = y * height_scale;
|
||||||
|
int32 y0 = static_cast<int32>(std::floor(input_y));
|
||||||
|
int32 y1 = std::min(y0 + 1, input_height - 1);
|
||||||
|
for (int x = 0; x < output_width; ++x) {
|
||||||
|
float input_x = x * width_scale;
|
||||||
|
int32 x0 = static_cast<int32>(input_x);
|
||||||
|
int32 x1 = std::min(x0 + 1, input_width - 1);
|
||||||
|
|
||||||
|
int32 input_offset[4] = {
|
||||||
|
Offset(input_dims, 0, x0, y0, b), Offset(input_dims, 0, x1, y0, b),
|
||||||
|
Offset(input_dims, 0, x0, y1, b), Offset(input_dims, 0, x1, y1, b)};
|
||||||
|
float scale[4] = {(1 - (input_y - y0)) * (1 - (input_x - x0)),
|
||||||
|
(1 - (input_y - y0)) * (input_x - x0),
|
||||||
|
(input_y - y0) * (1 - (input_x - x0)),
|
||||||
|
(input_y - y0) * (input_x - x0)};
|
||||||
|
|
||||||
|
for (int d = 0; d < depth; d++) {
|
||||||
|
const T* input_ptr = &input_data[d];
|
||||||
|
*output_ptr++ = static_cast<T>(input_ptr[input_offset[0]] * scale[0] +
|
||||||
|
input_ptr[input_offset[1]] * scale[1] +
|
||||||
|
input_ptr[input_offset[2]] * scale[2] +
|
||||||
|
input_ptr[input_offset[3]] * scale[3]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
|
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
|
||||||
const int32* output_size_data,
|
const int32* output_size_data,
|
||||||
const Dims<4>& output_size_dims, float* output_data,
|
const Dims<4>& output_size_dims, float* output_data,
|
||||||
@ -5762,6 +5802,41 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(prabhumk): This is not a real quantized bilinear. It does not use int8
|
||||||
|
// or int16 arithmetic.
|
||||||
|
inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
|
||||||
|
const int32* output_size_data,
|
||||||
|
const Dims<4>& output_size_dims, uint8* output_data,
|
||||||
|
const Dims<4>& output_dims, bool align_corners) {
|
||||||
|
gemmlowp::ScopedProfilingLabel label("ResizeBilinear");
|
||||||
|
int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
|
||||||
|
int32 input_height = ArraySize(input_dims, 2);
|
||||||
|
int32 input_width = ArraySize(input_dims, 1);
|
||||||
|
int32 depth = MatchingArraySize(input_dims, 0, output_dims, 0);
|
||||||
|
|
||||||
|
TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 3), 1);
|
||||||
|
TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 2), 1);
|
||||||
|
TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 1), 1);
|
||||||
|
TFLITE_DCHECK_EQ(ArraySize(output_size_dims, 0), 2);
|
||||||
|
int32 output_height = output_size_data[Offset(output_size_dims, 0, 0, 0, 0)];
|
||||||
|
int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)];
|
||||||
|
|
||||||
|
float height_scale =
|
||||||
|
(align_corners && output_height > 1)
|
||||||
|
? (static_cast<float>(input_height - 1) / (output_height - 1))
|
||||||
|
: (static_cast<float>(input_height) / output_height);
|
||||||
|
|
||||||
|
float width_scale =
|
||||||
|
(align_corners && output_width > 1)
|
||||||
|
? (static_cast<float>(input_width - 1) / (output_width - 1))
|
||||||
|
: (static_cast<float>(input_width) / output_width);
|
||||||
|
|
||||||
|
ResizeBilinearGenericSmallChannel<uint8>(
|
||||||
|
input_data, input_dims, output_data, output_dims, batches, input_height,
|
||||||
|
input_width, depth, output_height, output_width, height_scale,
|
||||||
|
width_scale);
|
||||||
|
}
|
||||||
|
|
||||||
// legacy, for compatibility with old checked-in code
|
// legacy, for compatibility with old checked-in code
|
||||||
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
|
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
|
||||||
const int32* output_size_data,
|
const int32* output_size_data,
|
||||||
@ -5771,6 +5846,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
|
|||||||
output_data, output_dims, /*align_corners=*/false);
|
output_data, output_dims, /*align_corners=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// legacy, for compatibility with old checked-in code
|
||||||
|
inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
|
||||||
|
const int32* output_size_data,
|
||||||
|
const Dims<4>& output_size_dims, uint8* output_data,
|
||||||
|
const Dims<4>& output_dims) {
|
||||||
|
ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
|
||||||
|
output_data, output_dims, /*align_corners=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
|
inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims,
|
||||||
const int32* block_shape_data,
|
const int32* block_shape_data,
|
||||||
|
@ -3202,9 +3202,10 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
|
template <typename T>
|
||||||
|
inline void ResizeBilinear(const T* input_data, const Dims<4>& input_dims,
|
||||||
const int32* output_size_data,
|
const int32* output_size_data,
|
||||||
const Dims<4>& output_size_dims, float* output_data,
|
const Dims<4>& output_size_dims, T* output_data,
|
||||||
const Dims<4>& output_dims, bool align_corners) {
|
const Dims<4>& output_dims, bool align_corners) {
|
||||||
int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
|
int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3);
|
||||||
int32 input_height = ArraySize(input_dims, 2);
|
int32 input_height = ArraySize(input_dims, 2);
|
||||||
@ -3236,15 +3237,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
|
|||||||
int32 x0 = static_cast<int32>(std::floor(input_x));
|
int32 x0 = static_cast<int32>(std::floor(input_x));
|
||||||
int32 x1 = std::min(x0 + 1, input_width - 1);
|
int32 x1 = std::min(x0 + 1, input_width - 1);
|
||||||
for (int c = 0; c < depth; ++c) {
|
for (int c = 0; c < depth; ++c) {
|
||||||
float interpolation = input_data[Offset(input_dims, c, x0, y0, b)] *
|
T interpolation =
|
||||||
(1 - (input_y - y0)) *
|
static_cast<T>(input_data[Offset(input_dims, c, x0, y0, b)] *
|
||||||
(1 - (input_x - x0)) +
|
(1 - (input_y - y0)) * (1 - (input_x - x0)) +
|
||||||
input_data[Offset(input_dims, c, x0, y1, b)] *
|
input_data[Offset(input_dims, c, x0, y1, b)] *
|
||||||
(input_y - y0) * (1 - (input_x - x0)) +
|
(input_y - y0) * (1 - (input_x - x0)) +
|
||||||
input_data[Offset(input_dims, c, x1, y0, b)] *
|
input_data[Offset(input_dims, c, x1, y0, b)] *
|
||||||
(1 - (input_y - y0)) * (input_x - x0) +
|
(1 - (input_y - y0)) * (input_x - x0) +
|
||||||
input_data[Offset(input_dims, c, x1, y1, b)] *
|
input_data[Offset(input_dims, c, x1, y1, b)] *
|
||||||
(input_y - y0) * (input_x - x0);
|
(input_y - y0) * (input_x - x0));
|
||||||
output_data[Offset(output_dims, c, x, y, b)] = interpolation;
|
output_data[Offset(output_dims, c, x, y, b)] = interpolation;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3257,8 +3258,18 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims,
|
|||||||
const int32* output_size_data,
|
const int32* output_size_data,
|
||||||
const Dims<4>& output_size_dims, float* output_data,
|
const Dims<4>& output_size_dims, float* output_data,
|
||||||
const Dims<4>& output_dims) {
|
const Dims<4>& output_dims) {
|
||||||
ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims,
|
ResizeBilinear<float>(input_data, input_dims, output_size_data,
|
||||||
output_data, output_dims, /*align_corners=*/false);
|
output_size_dims, output_data, output_dims,
|
||||||
|
/*align_corners=*/false);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void ResizeBilinear(const uint8* input_data, const Dims<4>& input_dims,
|
||||||
|
const int32* output_size_data,
|
||||||
|
const Dims<4>& output_size_dims, uint8* output_data,
|
||||||
|
const Dims<4>& output_dims) {
|
||||||
|
ResizeBilinear<uint8>(input_data, input_dims, output_size_data,
|
||||||
|
output_size_dims, output_data, output_dims,
|
||||||
|
/*align_corners=*/false);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
@ -24,9 +24,10 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tflite {
|
namespace tflite {
|
||||||
namespace {
|
namespace {
|
||||||
|
template <typename T>
|
||||||
void TestOneResizeBilinear(int batch, int depth, int input_width,
|
void TestOneResizeBilinear(int batch, int depth, int input_width,
|
||||||
int input_height, int output_width,
|
int input_height, int output_width,
|
||||||
int output_height) {
|
int output_height, float error_threshold) {
|
||||||
Dims<4> input_dims_inference =
|
Dims<4> input_dims_inference =
|
||||||
MakeDimsForInference(depth, input_width, input_height, batch);
|
MakeDimsForInference(depth, input_width, input_height, batch);
|
||||||
Dims<4> output_dims_inference =
|
Dims<4> output_dims_inference =
|
||||||
@ -36,14 +37,15 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
|
|||||||
const int output_buffer_size =
|
const int output_buffer_size =
|
||||||
RequiredBufferSizeForDims(output_dims_inference);
|
RequiredBufferSizeForDims(output_dims_inference);
|
||||||
|
|
||||||
std::vector<float> input_data(input_buffer_size, 0);
|
std::vector<T> input_data(input_buffer_size, 0);
|
||||||
std::vector<float> reference_output_data(output_buffer_size, 0);
|
std::vector<T> reference_output_data(output_buffer_size, 0);
|
||||||
// Initialize the output data with something other than zero, so we can catch
|
// Initialize the output data with something other than zero, so we can catch
|
||||||
// issue with kernels failing to initialize the output.
|
// issue with kernels failing to initialize the output.
|
||||||
std::vector<float> output_data(output_buffer_size, 3.1415);
|
std::vector<T> output_data(output_buffer_size, 3);
|
||||||
|
|
||||||
const float input_amplitude = 1.f;
|
const T min_amplitude = static_cast<T>(0);
|
||||||
FillRandom(&input_data, -input_amplitude, input_amplitude);
|
const T max_amplitude = static_cast<T>(255);
|
||||||
|
FillRandom(&input_data, min_amplitude, max_amplitude);
|
||||||
|
|
||||||
Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1);
|
Dims<4> output_size_dims = MakeDimsForInference(2, 1, 1, 1);
|
||||||
std::vector<int32> output_size_data = {output_height, output_width};
|
std::vector<int32> output_size_data = {output_height, output_width};
|
||||||
@ -58,14 +60,46 @@ void TestOneResizeBilinear(int batch, int depth, int input_width,
|
|||||||
double sum_diff = 0;
|
double sum_diff = 0;
|
||||||
float max_abs_val = 0;
|
float max_abs_val = 0;
|
||||||
for (int i = 0; i < output_buffer_size; i++) {
|
for (int i = 0; i < output_buffer_size; i++) {
|
||||||
sum_diff += std::abs(output_data[i] - reference_output_data[i]);
|
sum_diff += std::abs(static_cast<float>(output_data[i]) -
|
||||||
max_abs_val = std::max(max_abs_val, std::abs(reference_output_data[i]));
|
static_cast<float>(reference_output_data[i]));
|
||||||
|
max_abs_val = std::max(
|
||||||
|
max_abs_val, std::abs(static_cast<float>(reference_output_data[i])));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (sum_diff != 0.f) {
|
if (sum_diff != 0.f) {
|
||||||
const float mean_diff = static_cast<float>(sum_diff / output_buffer_size);
|
const float mean_diff = static_cast<float>(sum_diff / output_buffer_size);
|
||||||
const float relative_error = std::abs(mean_diff) / max_abs_val;
|
const float relative_error = std::abs(mean_diff) / max_abs_val;
|
||||||
ASSERT_LT(relative_error, 1e-5f);
|
ASSERT_LT(relative_error, error_threshold);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ResizeBilinear, TestResizeBilinear8Bit) {
|
||||||
|
const int kTestsToRun = 100 * 1000;
|
||||||
|
for (int i = 0; i < kTestsToRun; i++) {
|
||||||
|
const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
|
||||||
|
const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
|
||||||
|
const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
||||||
|
const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
||||||
|
const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
||||||
|
const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
||||||
|
|
||||||
|
TestOneResizeBilinear<uint8>(batch, depth, input_width, input_height,
|
||||||
|
output_width, output_height, 0.025);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ResizeBilinear2x2, TestResizeBilinear8Bit) {
|
||||||
|
const int kTestsToRun = 100 * 1000;
|
||||||
|
for (int i = 0; i < kTestsToRun; i++) {
|
||||||
|
const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
|
||||||
|
const int depth = ExponentialRandomPositiveInt(0.9f, 6, 50);
|
||||||
|
const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
||||||
|
const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
||||||
|
const int output_width = input_width * 2;
|
||||||
|
const int output_height = input_height * 2;
|
||||||
|
|
||||||
|
TestOneResizeBilinear<uint8>(batch, depth, input_width, input_height,
|
||||||
|
output_width, output_height, 1e-5);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,8 +113,8 @@ TEST(ResizeBilinear, TestResizeBilinear) {
|
|||||||
const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
const int output_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
||||||
const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
const int output_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
|
||||||
|
|
||||||
TestOneResizeBilinear(batch, depth, input_width, input_height, output_width,
|
TestOneResizeBilinear<float>(batch, depth, input_width, input_height,
|
||||||
output_height);
|
output_width, output_height, 1e-5);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -94,8 +128,8 @@ TEST(ResizeBilinear2x2, TestResizeBilinear) {
|
|||||||
const int output_width = input_width * 2;
|
const int output_width = input_width * 2;
|
||||||
const int output_height = input_height * 2;
|
const int output_height = input_height * 2;
|
||||||
|
|
||||||
TestOneResizeBilinear(batch, depth, input_width, input_height, output_width,
|
TestOneResizeBilinear<float>(batch, depth, input_width, input_height,
|
||||||
output_height);
|
output_width, output_height, 1e-5);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
@ -121,6 +121,10 @@ class RuntimeShape {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void BuildFrom(const std::initializer_list<int> init_list) {
|
||||||
|
BuildFrom<const std::initializer_list<int>>(init_list);
|
||||||
|
}
|
||||||
|
|
||||||
// Returns the total count of elements, that is the size when flattened into a
|
// Returns the total count of elements, that is the size when flattened into a
|
||||||
// vector.
|
// vector.
|
||||||
inline int FlatSize() const {
|
inline int FlatSize() const {
|
||||||
|
@ -61,12 +61,10 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4);
|
||||||
TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
|
TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1);
|
||||||
|
|
||||||
// TODO(ahentz): Our current implementations only support float32.
|
|
||||||
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
|
|
||||||
TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
|
TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32);
|
||||||
// ResizeBilinear creates a float tensor even when the input is made of
|
// ResizeBilinear creates a float tensor even when the input is made of
|
||||||
// integers.
|
// integers.
|
||||||
output->type = kTfLiteFloat32;
|
output->type = input->type;
|
||||||
|
|
||||||
if (!IsConstantTensor(size)) {
|
if (!IsConstantTensor(size)) {
|
||||||
SetTensorToDynamic(output);
|
SetTensorToDynamic(output);
|
||||||
@ -90,17 +88,24 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (output->type == kTfLiteFloat32) {
|
if (output->type == kTfLiteFloat32) {
|
||||||
#define TF_LITE_RESIZE_BILINEAR(type) \
|
#define TF_LITE_RESIZE_BILINEAR(type, datatype) \
|
||||||
type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
|
type::ResizeBilinear(GetTensorData<datatype>(input), GetTensorDims(input), \
|
||||||
GetTensorData<int32>(size), GetTensorDims(size), \
|
GetTensorData<int32>(size), GetTensorDims(size), \
|
||||||
GetTensorData<float>(output), GetTensorDims(output), \
|
GetTensorData<datatype>(output), GetTensorDims(output), \
|
||||||
params->align_corners)
|
params->align_corners)
|
||||||
|
|
||||||
if (kernel_type == kReference) {
|
if (kernel_type == kReference) {
|
||||||
TF_LITE_RESIZE_BILINEAR(reference_ops);
|
TF_LITE_RESIZE_BILINEAR(reference_ops, float);
|
||||||
}
|
}
|
||||||
if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
|
if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
|
||||||
TF_LITE_RESIZE_BILINEAR(optimized_ops);
|
TF_LITE_RESIZE_BILINEAR(optimized_ops, float);
|
||||||
|
}
|
||||||
|
} else if (output->type == kTfLiteUInt8) {
|
||||||
|
if (kernel_type == kReference) {
|
||||||
|
TF_LITE_RESIZE_BILINEAR(reference_ops, uint8_t);
|
||||||
|
}
|
||||||
|
if (kernel_type == kGenericOptimized || kernel_type == kNeonOptimized) {
|
||||||
|
TF_LITE_RESIZE_BILINEAR(optimized_ops, uint8_t);
|
||||||
}
|
}
|
||||||
#undef TF_LITE_RESIZE_BILINEAR
|
#undef TF_LITE_RESIZE_BILINEAR
|
||||||
} else {
|
} else {
|
||||||
|
@ -22,6 +22,7 @@ namespace tflite {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
using ::testing::ElementsAreArray;
|
using ::testing::ElementsAreArray;
|
||||||
|
using uint8 = std::uint8_t;
|
||||||
|
|
||||||
class ResizeBilinearOpModel : public SingleOpModel {
|
class ResizeBilinearOpModel : public SingleOpModel {
|
||||||
public:
|
public:
|
||||||
@ -34,7 +35,7 @@ class ResizeBilinearOpModel : public SingleOpModel {
|
|||||||
} else {
|
} else {
|
||||||
size_ = AddInput({TensorType_INT32, {2}});
|
size_ = AddInput({TensorType_INT32, {2}});
|
||||||
}
|
}
|
||||||
output_ = AddOutput(TensorType_FLOAT32); // Always float.
|
output_ = AddOutput(input.type);
|
||||||
SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
|
SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
|
||||||
BuiltinOptions_ResizeBilinearOptions,
|
BuiltinOptions_ResizeBilinearOptions,
|
||||||
CreateResizeBilinearOptions(builder_).Union());
|
CreateResizeBilinearOptions(builder_).Union());
|
||||||
@ -45,12 +46,16 @@ class ResizeBilinearOpModel : public SingleOpModel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetInput(std::initializer_list<float> data) {
|
template <typename T>
|
||||||
|
void SetInput(std::initializer_list<T> data) {
|
||||||
PopulateTensor(input_, data);
|
PopulateTensor(input_, data);
|
||||||
}
|
}
|
||||||
void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }
|
void SetSize(std::initializer_list<int> data) { PopulateTensor(size_, data); }
|
||||||
|
|
||||||
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
|
template <typename T>
|
||||||
|
std::vector<T> GetOutput() {
|
||||||
|
return ExtractVector<T>(output_);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int input_;
|
int input_;
|
||||||
@ -60,51 +65,112 @@ class ResizeBilinearOpModel : public SingleOpModel {
|
|||||||
|
|
||||||
TEST(ResizeBilinearOpTest, HorizontalResize) {
|
TEST(ResizeBilinearOpTest, HorizontalResize) {
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}});
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}});
|
||||||
m.SetInput({3, 6});
|
m.SetInput<float>({3, 6});
|
||||||
m.SetSize({1, 3});
|
m.SetSize({1, 3});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
|
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
|
||||||
const_m.SetInput({3, 6});
|
const_m.SetInput<float>({3, 6});
|
||||||
const_m.Invoke();
|
const_m.Invoke();
|
||||||
EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
EXPECT_THAT(const_m.GetOutput<float>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ResizeBilinearOpTest, HorizontalResize8Bit) {
|
||||||
|
ResizeBilinearOpModel m({TensorType_UINT8, {1, 1, 2, 1}});
|
||||||
|
m.SetInput<uint8>({3, 6});
|
||||||
|
m.SetSize({1, 3});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<uint8>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
||||||
|
|
||||||
|
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 1, 2, 1}}, {1, 3});
|
||||||
|
const_m.SetInput<uint8>({3, 6});
|
||||||
|
const_m.Invoke();
|
||||||
|
EXPECT_THAT(const_m.GetOutput<uint8>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({3, 5, 6})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, VerticalResize) {
|
TEST(ResizeBilinearOpTest, VerticalResize) {
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}});
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}});
|
||||||
m.SetInput({3, 9});
|
m.SetInput<float>({3, 9});
|
||||||
m.SetSize({3, 1});
|
m.SetSize({3, 1});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
EXPECT_THAT(m.GetOutput<float>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
|
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
|
||||||
const_m.SetInput({3, 9});
|
const_m.SetInput<float>({3, 9});
|
||||||
const_m.Invoke();
|
const_m.Invoke();
|
||||||
EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
EXPECT_THAT(const_m.GetOutput<float>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ResizeBilinearOpTest, VerticalResize8Bit) {
|
||||||
|
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 1, 1}});
|
||||||
|
m.SetInput<uint8>({3, 9});
|
||||||
|
m.SetSize({3, 1});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<uint8>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
||||||
|
|
||||||
|
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 1, 1}}, {3, 1});
|
||||||
|
const_m.SetInput<uint8>({3, 9});
|
||||||
|
const_m.Invoke();
|
||||||
|
EXPECT_THAT(const_m.GetOutput<uint8>(),
|
||||||
|
ElementsAreArray(ArrayFloatNear({3, 7, 9})));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
|
TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}});
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}});
|
||||||
m.SetInput({
|
m.SetInput<float>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12 //
|
9, 12 //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
m.SetSize({3, 3});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 5, 6, //
|
3, 5, 6, //
|
||||||
7, 9, 10, //
|
7, 9, 10, //
|
||||||
9, 11, 12, //
|
9, 11, 12, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
|
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
|
||||||
const_m.SetInput({
|
const_m.SetInput<float>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12 //
|
9, 12 //
|
||||||
});
|
});
|
||||||
const_m.Invoke();
|
const_m.Invoke();
|
||||||
EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
|
3, 5, 6, //
|
||||||
|
7, 9, 10, //
|
||||||
|
9, 11, 12, //
|
||||||
|
})));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ResizeBilinearOpTest, TwoDimensionalResize8Bit) {
|
||||||
|
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 1}});
|
||||||
|
m.SetInput<uint8>({
|
||||||
|
3, 6, //
|
||||||
|
9, 12 //
|
||||||
|
});
|
||||||
|
m.SetSize({3, 3});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
|
3, 5, 6, //
|
||||||
|
7, 9, 10, //
|
||||||
|
9, 11, 12, //
|
||||||
|
})));
|
||||||
|
|
||||||
|
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 1}}, {3, 3});
|
||||||
|
const_m.SetInput<uint8>({
|
||||||
|
3, 6, //
|
||||||
|
9, 12 //
|
||||||
|
});
|
||||||
|
const_m.Invoke();
|
||||||
|
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 5, 6, //
|
3, 5, 6, //
|
||||||
7, 9, 10, //
|
7, 9, 10, //
|
||||||
9, 11, 12, //
|
9, 11, 12, //
|
||||||
@ -113,7 +179,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
|
|||||||
|
|
||||||
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
|
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}});
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}});
|
||||||
m.SetInput({
|
m.SetInput<float>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12, //
|
9, 12, //
|
||||||
4, 10, //
|
4, 10, //
|
||||||
@ -121,7 +187,7 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
|
|||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
m.SetSize({3, 3});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 5, 6, //
|
3, 5, 6, //
|
||||||
7, 9, 10, //
|
7, 9, 10, //
|
||||||
9, 11, 12, //
|
9, 11, 12, //
|
||||||
@ -131,14 +197,14 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
|
|||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3});
|
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3});
|
||||||
const_m.SetInput({
|
const_m.SetInput<float>({
|
||||||
3, 6, //
|
3, 6, //
|
||||||
9, 12, //
|
9, 12, //
|
||||||
4, 10, //
|
4, 10, //
|
||||||
10, 16 //
|
10, 16 //
|
||||||
});
|
});
|
||||||
const_m.Invoke();
|
const_m.Invoke();
|
||||||
EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 5, 6, //
|
3, 5, 6, //
|
||||||
7, 9, 10, //
|
7, 9, 10, //
|
||||||
9, 11, 12, //
|
9, 11, 12, //
|
||||||
@ -150,31 +216,94 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
|
|||||||
|
|
||||||
TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
|
TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
|
||||||
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}});
|
ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}});
|
||||||
m.SetInput({
|
m.SetInput<float>({
|
||||||
3, 4, 6, 10, //
|
3, 4, 6, 10, //
|
||||||
9, 10, 12, 16, //
|
9, 10, 12, 16, //
|
||||||
});
|
});
|
||||||
m.SetSize({3, 3});
|
m.SetSize({3, 3});
|
||||||
m.Invoke();
|
m.Invoke();
|
||||||
EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 4, 5, 8, 6, 10, //
|
3, 4, 5, 8, 6, 10, //
|
||||||
7, 8, 9, 12, 10, 14, //
|
7, 8, 9, 12, 10, 14, //
|
||||||
9, 10, 11, 14, 12, 16, //
|
9, 10, 11, 14, 12, 16, //
|
||||||
})));
|
})));
|
||||||
|
|
||||||
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3});
|
ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3});
|
||||||
const_m.SetInput({
|
const_m.SetInput<float>({
|
||||||
3, 4, 6, 10, //
|
3, 4, 6, 10, //
|
||||||
9, 10, 12, 16, //
|
9, 10, 12, 16, //
|
||||||
});
|
});
|
||||||
const_m.Invoke();
|
const_m.Invoke();
|
||||||
EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
|
EXPECT_THAT(const_m.GetOutput<float>(), ElementsAreArray(ArrayFloatNear({
|
||||||
3, 4, 5, 8, 6, 10, //
|
3, 4, 5, 8, 6, 10, //
|
||||||
7, 8, 9, 12, 10, 14, //
|
7, 8, 9, 12, 10, 14, //
|
||||||
9, 10, 11, 14, 12, 16, //
|
9, 10, 11, 14, 12, 16, //
|
||||||
})));
|
})));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches8Bit) {
|
||||||
|
ResizeBilinearOpModel m({TensorType_UINT8, {2, 2, 2, 1}});
|
||||||
|
m.SetInput<uint8>({
|
||||||
|
3, 6, //
|
||||||
|
9, 12, //
|
||||||
|
4, 10, //
|
||||||
|
10, 16 //
|
||||||
|
});
|
||||||
|
m.SetSize({3, 3});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
|
3, 5, 6, //
|
||||||
|
7, 9, 10, //
|
||||||
|
9, 11, 12, //
|
||||||
|
4, 8, 10, //
|
||||||
|
8, 12, 14, //
|
||||||
|
10, 13, 16, //
|
||||||
|
})));
|
||||||
|
|
||||||
|
ResizeBilinearOpModel const_m({TensorType_UINT8, {2, 2, 2, 1}}, {3, 3});
|
||||||
|
const_m.SetInput<uint8>({
|
||||||
|
3, 6, //
|
||||||
|
9, 12, //
|
||||||
|
4, 10, //
|
||||||
|
10, 16 //
|
||||||
|
});
|
||||||
|
const_m.Invoke();
|
||||||
|
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
|
3, 5, 6, //
|
||||||
|
7, 9, 10, //
|
||||||
|
9, 11, 12, //
|
||||||
|
4, 8, 10, //
|
||||||
|
8, 12, 14, //
|
||||||
|
10, 13, 16, //
|
||||||
|
})));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ResizeBilinearOpTest, ThreeDimensionalResize8Bit) {
|
||||||
|
ResizeBilinearOpModel m({TensorType_UINT8, {1, 2, 2, 2}});
|
||||||
|
m.SetInput<uint8>({
|
||||||
|
3, 4, 6, 10, //
|
||||||
|
9, 10, 12, 16, //
|
||||||
|
});
|
||||||
|
m.SetSize({3, 3});
|
||||||
|
m.Invoke();
|
||||||
|
EXPECT_THAT(m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
|
3, 4, 5, 8, 6, 10, //
|
||||||
|
7, 8, 9, 12, 10, 14, //
|
||||||
|
9, 10, 11, 13, 12, 16, //
|
||||||
|
})));
|
||||||
|
|
||||||
|
ResizeBilinearOpModel const_m({TensorType_UINT8, {1, 2, 2, 2}}, {3, 3});
|
||||||
|
const_m.SetInput<uint8>({
|
||||||
|
3, 4, 6, 10, //
|
||||||
|
9, 10, 12, 16, //
|
||||||
|
});
|
||||||
|
const_m.Invoke();
|
||||||
|
EXPECT_THAT(const_m.GetOutput<uint8>(), ElementsAreArray(ArrayFloatNear({
|
||||||
|
3, 4, 5, 8, 6, 10, //
|
||||||
|
7, 8, 9, 12, 10, 14, //
|
||||||
|
9, 10, 11, 13, 12, 16, //
|
||||||
|
})));
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tflite
|
} // namespace tflite
|
||||||
|
|
||||||
|
@ -322,12 +322,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
|
|
||||||
*builtin_data = nullptr;
|
*builtin_data = nullptr;
|
||||||
switch (op_type) {
|
switch (op_type) {
|
||||||
case BuiltinOperator_CALL:
|
|
||||||
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
|
|
||||||
// ok for now, since there is no call implementation either.
|
|
||||||
break;
|
|
||||||
case BuiltinOperator_CUSTOM:
|
|
||||||
break;
|
|
||||||
case BuiltinOperator_CONV_2D: {
|
case BuiltinOperator_CONV_2D: {
|
||||||
TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
|
TfLiteConvParams* params = MallocPOD<TfLiteConvParams>();
|
||||||
if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
|
if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) {
|
||||||
@ -343,22 +337,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
*builtin_data = reinterpret_cast<void*>(params);
|
*builtin_data = reinterpret_cast<void*>(params);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BuiltinOperator_TANH:
|
|
||||||
case BuiltinOperator_LOGISTIC:
|
|
||||||
case BuiltinOperator_RELU:
|
|
||||||
case BuiltinOperator_RELU_N1_TO_1:
|
|
||||||
case BuiltinOperator_RELU6:
|
|
||||||
case BuiltinOperator_CONCAT_EMBEDDINGS:
|
|
||||||
case BuiltinOperator_EXP:
|
|
||||||
case BuiltinOperator_TOPK_V2:
|
|
||||||
case BuiltinOperator_LOG_SOFTMAX:
|
|
||||||
case BuiltinOperator_DEQUANTIZE:
|
|
||||||
case BuiltinOperator_PRELU:
|
|
||||||
case BuiltinOperator_FLOOR:
|
|
||||||
case BuiltinOperator_NEG:
|
|
||||||
case BuiltinOperator_SIN:
|
|
||||||
case BuiltinOperator_LOG:
|
|
||||||
break;
|
|
||||||
case BuiltinOperator_CAST: {
|
case BuiltinOperator_CAST: {
|
||||||
TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
|
TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
|
||||||
if (auto* schema_params = op->builtin_options_as_CastOptions()) {
|
if (auto* schema_params = op->builtin_options_as_CastOptions()) {
|
||||||
@ -446,9 +424,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
*builtin_data = reinterpret_cast<void*>(params);
|
*builtin_data = reinterpret_cast<void*>(params);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BuiltinOperator_EMBEDDING_LOOKUP:
|
|
||||||
// no-op.
|
|
||||||
break;
|
|
||||||
case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
|
case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: {
|
||||||
TfLiteEmbeddingLookupSparseParams* params =
|
TfLiteEmbeddingLookupSparseParams* params =
|
||||||
MallocPOD<TfLiteEmbeddingLookupSparseParams>();
|
MallocPOD<TfLiteEmbeddingLookupSparseParams>();
|
||||||
@ -580,12 +555,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
*builtin_data = reinterpret_cast<void*>(params);
|
*builtin_data = reinterpret_cast<void*>(params);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BuiltinOperator_PAD: {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case BuiltinOperator_PADV2: {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case BuiltinOperator_RESHAPE: {
|
case BuiltinOperator_RESHAPE: {
|
||||||
auto* params = MallocPOD<TfLiteReshapeParams>();
|
auto* params = MallocPOD<TfLiteReshapeParams>();
|
||||||
if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
|
if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) {
|
||||||
@ -625,15 +594,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
*builtin_data = reinterpret_cast<void*>(params);
|
*builtin_data = reinterpret_cast<void*>(params);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BuiltinOperator_SPACE_TO_BATCH_ND: {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case BuiltinOperator_BATCH_TO_SPACE_ND: {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case BuiltinOperator_TRANSPOSE: {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case BuiltinOperator_MEAN: {
|
case BuiltinOperator_MEAN: {
|
||||||
auto* params = MallocPOD<TfLiteMeanParams>();
|
auto* params = MallocPOD<TfLiteMeanParams>();
|
||||||
if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
|
if (auto* schema_params = op->builtin_options_as_MeanOptions()) {
|
||||||
@ -673,10 +633,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
*builtin_data = reinterpret_cast<void*>(params);
|
*builtin_data = reinterpret_cast<void*>(params);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BuiltinOperator_MAXIMUM:
|
|
||||||
case BuiltinOperator_MINIMUM: {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case BuiltinOperator_ARG_MAX: {
|
case BuiltinOperator_ARG_MAX: {
|
||||||
auto* params = MallocPOD<TfLiteArgMaxParams>();
|
auto* params = MallocPOD<TfLiteArgMaxParams>();
|
||||||
if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
|
if (auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
|
||||||
@ -686,18 +642,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
*builtin_data = reinterpret_cast<void*>(params);
|
*builtin_data = reinterpret_cast<void*>(params);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BuiltinOperator_GREATER:
|
|
||||||
case BuiltinOperator_GREATER_EQUAL:
|
|
||||||
case BuiltinOperator_LESS:
|
|
||||||
case BuiltinOperator_LESS_EQUAL:
|
|
||||||
case BuiltinOperator_EQUAL:
|
|
||||||
case BuiltinOperator_NOT_EQUAL:
|
|
||||||
case BuiltinOperator_SELECT: {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case BuiltinOperator_SLICE: {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
case BuiltinOperator_TRANSPOSE_CONV: {
|
case BuiltinOperator_TRANSPOSE_CONV: {
|
||||||
TfLiteTransposeConvParams* params =
|
TfLiteTransposeConvParams* params =
|
||||||
MallocPOD<TfLiteTransposeConvParams>();
|
MallocPOD<TfLiteTransposeConvParams>();
|
||||||
@ -725,11 +669,47 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
|||||||
error_reporter->Report("DELEGATE op shouldn't exist in model.");
|
error_reporter->Report("DELEGATE op shouldn't exist in model.");
|
||||||
return kTfLiteError;
|
return kTfLiteError;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Below are the ops with no builtin_data strcture.
|
||||||
|
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
||||||
|
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
|
||||||
|
// ok for now, since there is no call implementation either.
|
||||||
|
case BuiltinOperator_CALL:
|
||||||
|
case BuiltinOperator_CONCAT_EMBEDDINGS:
|
||||||
|
case BuiltinOperator_CUSTOM:
|
||||||
|
case BuiltinOperator_DEQUANTIZE:
|
||||||
|
case BuiltinOperator_EMBEDDING_LOOKUP:
|
||||||
|
case BuiltinOperator_EQUAL:
|
||||||
|
case BuiltinOperator_EXP:
|
||||||
case BuiltinOperator_EXPAND_DIMS:
|
case BuiltinOperator_EXPAND_DIMS:
|
||||||
case BuiltinOperator_TILE: {
|
case BuiltinOperator_FLOOR:
|
||||||
|
case BuiltinOperator_GREATER:
|
||||||
|
case BuiltinOperator_GREATER_EQUAL:
|
||||||
|
case BuiltinOperator_LESS:
|
||||||
|
case BuiltinOperator_LESS_EQUAL:
|
||||||
|
case BuiltinOperator_LOG:
|
||||||
|
case BuiltinOperator_LOGISTIC:
|
||||||
|
case BuiltinOperator_LOG_SOFTMAX:
|
||||||
|
case BuiltinOperator_MAXIMUM:
|
||||||
|
case BuiltinOperator_MINIMUM:
|
||||||
|
case BuiltinOperator_NEG:
|
||||||
|
case BuiltinOperator_NOT_EQUAL:
|
||||||
|
case BuiltinOperator_PAD:
|
||||||
|
case BuiltinOperator_PADV2:
|
||||||
|
case BuiltinOperator_PRELU:
|
||||||
|
case BuiltinOperator_RELU:
|
||||||
|
case BuiltinOperator_RELU6:
|
||||||
|
case BuiltinOperator_RELU_N1_TO_1:
|
||||||
|
case BuiltinOperator_SELECT:
|
||||||
|
case BuiltinOperator_SIN:
|
||||||
|
case BuiltinOperator_SLICE:
|
||||||
|
case BuiltinOperator_SPACE_TO_BATCH_ND:
|
||||||
|
case BuiltinOperator_TANH:
|
||||||
|
case BuiltinOperator_TILE:
|
||||||
|
case BuiltinOperator_TOPK_V2:
|
||||||
|
case BuiltinOperator_TRANSPOSE:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return kTfLiteOk;
|
return kTfLiteOk;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -234,7 +234,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
|
|||||||
next_id++;
|
next_id++;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto add_add_params = [&add_scalar_int32]() { add_scalar_int32(0); };
|
auto add_add_params = [&add_scalar_int32](void* data) {
|
||||||
|
auto* builtin = reinterpret_cast<TfLiteAddParams*>(data);
|
||||||
|
add_scalar_int32(builtin->activation);
|
||||||
|
};
|
||||||
|
|
||||||
auto add_pooling_params = [&add_scalar_int32](void* data) {
|
auto add_pooling_params = [&add_scalar_int32](void* data) {
|
||||||
auto builtin = reinterpret_cast<TfLitePoolParams*>(data);
|
auto builtin = reinterpret_cast<TfLitePoolParams*>(data);
|
||||||
@ -345,11 +348,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
|
|||||||
switch (builtin) {
|
switch (builtin) {
|
||||||
case tflite::BuiltinOperator_ADD:
|
case tflite::BuiltinOperator_ADD:
|
||||||
nn_op_type = ANEURALNETWORKS_ADD;
|
nn_op_type = ANEURALNETWORKS_ADD;
|
||||||
add_add_params();
|
add_add_params(node.builtin_data);
|
||||||
break;
|
break;
|
||||||
case tflite::BuiltinOperator_MUL:
|
case tflite::BuiltinOperator_MUL:
|
||||||
nn_op_type = ANEURALNETWORKS_MUL;
|
nn_op_type = ANEURALNETWORKS_MUL;
|
||||||
add_add_params();
|
add_add_params(node.builtin_data);
|
||||||
break;
|
break;
|
||||||
case tflite::BuiltinOperator_AVERAGE_POOL_2D:
|
case tflite::BuiltinOperator_AVERAGE_POOL_2D:
|
||||||
add_pooling_params(node.builtin_data);
|
add_pooling_params(node.builtin_data);
|
||||||
|
@ -2,9 +2,11 @@ package(default_visibility = ["//visibility:public"])
|
|||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
|
||||||
|
|
||||||
common_copts = [
|
common_copts = [
|
||||||
"-Wall",
|
"-Wall",
|
||||||
]
|
] + tflite_copts()
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "profiler",
|
name = "profiler",
|
||||||
@ -36,12 +38,14 @@ cc_library(
|
|||||||
name = "time",
|
name = "time",
|
||||||
srcs = ["time.cc"],
|
srcs = ["time.cc"],
|
||||||
hdrs = ["time.h"],
|
hdrs = ["time.h"],
|
||||||
|
copts = common_copts,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "profile_summarizer",
|
name = "profile_summarizer",
|
||||||
srcs = ["profile_summarizer.cc"],
|
srcs = ["profile_summarizer.cc"],
|
||||||
hdrs = ["profile_summarizer.h"],
|
hdrs = ["profile_summarizer.h"],
|
||||||
|
copts = common_copts,
|
||||||
deps = [
|
deps = [
|
||||||
":profiler",
|
":profiler",
|
||||||
"//tensorflow/contrib/lite:framework",
|
"//tensorflow/contrib/lite:framework",
|
||||||
@ -53,6 +57,7 @@ cc_library(
|
|||||||
cc_test(
|
cc_test(
|
||||||
name = "profile_summarizer_test",
|
name = "profile_summarizer_test",
|
||||||
srcs = ["profile_summarizer_test.cc"],
|
srcs = ["profile_summarizer_test.cc"],
|
||||||
|
copts = common_copts,
|
||||||
deps = [
|
deps = [
|
||||||
":profile_summarizer",
|
":profile_summarizer",
|
||||||
"//tensorflow/contrib/lite:framework",
|
"//tensorflow/contrib/lite:framework",
|
||||||
|
@ -111,8 +111,7 @@ def tensor_name(x):
|
|||||||
return x.name.split(":")[0]
|
return x.name.split(":")[0]
|
||||||
|
|
||||||
|
|
||||||
def toco_convert(input_data,
|
def build_toco_convert_protos(input_tensors,
|
||||||
input_tensors,
|
|
||||||
output_tensors,
|
output_tensors,
|
||||||
inference_type=lite_constants.FLOAT,
|
inference_type=lite_constants.FLOAT,
|
||||||
inference_input_type=None,
|
inference_input_type=None,
|
||||||
@ -127,21 +126,20 @@ def toco_convert(input_data,
|
|||||||
quantize_weights=False,
|
quantize_weights=False,
|
||||||
dump_graphviz_dir=None,
|
dump_graphviz_dir=None,
|
||||||
dump_graphviz_video=False):
|
dump_graphviz_video=False):
|
||||||
"""Convert a model using TOCO from `input_format` to `output_format`.
|
"""Builds protocol buffers describing a conversion of a model using TOCO.
|
||||||
|
|
||||||
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
|
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
|
||||||
case the default `input_format` and `output_format` are sufficient.
|
case the default `input_format` and `output_format` are sufficient.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_data: Input data (i.e. often `sess.graph_def`).
|
|
||||||
input_tensors: List of input tensors. Type and shape are computed using
|
input_tensors: List of input tensors. Type and shape are computed using
|
||||||
`foo.get_shape()` and `foo.dtype`.
|
`foo.get_shape()` and `foo.dtype`.
|
||||||
output_tensors: List of output tensors (only .name is used from this).
|
output_tensors: List of output tensors (only .name is used from this).
|
||||||
inference_type: Target data type of arrays in the output file. Currently
|
inference_type: Target data type of arrays in the output file. Currently
|
||||||
must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
|
must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default FLOAT)
|
||||||
inference_input_type: Target data type of input arrays. Allows for a
|
inference_input_type: Target data type of input arrays. Allows for a
|
||||||
different type for input arrays in the case of quantization. Currently
|
different type for input arrays in the case of quantization. Currently
|
||||||
must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
|
must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default `inference_type`)
|
||||||
input_format: Type of data to read Currently must be
|
input_format: Type of data to read Currently must be
|
||||||
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
|
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
|
||||||
output_format: Output file format. Currently must be `{TFLITE,
|
output_format: Output file format. Currently must be `{TFLITE,
|
||||||
@ -180,8 +178,8 @@ def toco_convert(input_data,
|
|||||||
every graph transformation. (default False)
|
every graph transformation. (default False)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The converted data. For example if TFLite was the destination, then
|
model_flags, toco_flags: two protocol buffers describing the conversion
|
||||||
this will be a tflite flatbuffer in a bytes array.
|
process.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If the input tensor type is unknown
|
ValueError: If the input tensor type is unknown
|
||||||
@ -204,7 +202,6 @@ def toco_convert(input_data,
|
|||||||
if dump_graphviz_dir:
|
if dump_graphviz_dir:
|
||||||
toco.dump_graphviz_dir = dump_graphviz_dir
|
toco.dump_graphviz_dir = dump_graphviz_dir
|
||||||
toco.dump_graphviz_include_video = dump_graphviz_video
|
toco.dump_graphviz_include_video = dump_graphviz_video
|
||||||
|
|
||||||
model = _model_flags_pb2.ModelFlags()
|
model = _model_flags_pb2.ModelFlags()
|
||||||
model.change_concat_input_ranges = change_concat_input_ranges
|
model.change_concat_input_ranges = change_concat_input_ranges
|
||||||
for idx, input_tensor in enumerate(input_tensors):
|
for idx, input_tensor in enumerate(input_tensors):
|
||||||
@ -216,7 +213,8 @@ def toco_convert(input_data,
|
|||||||
tflite_input_type = lite_constants.INT64
|
tflite_input_type = lite_constants.INT64
|
||||||
elif input_tensor.dtype == _dtypes.uint8:
|
elif input_tensor.dtype == _dtypes.uint8:
|
||||||
tflite_input_type = lite_constants.QUANTIZED_UINT8
|
tflite_input_type = lite_constants.QUANTIZED_UINT8
|
||||||
# TODO(aselle): Insert strings when they are available
|
elif input_tensor.dtype == _dtypes.string:
|
||||||
|
tflite_input_type = lite_constants.STRING
|
||||||
else:
|
else:
|
||||||
raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
|
raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
|
||||||
input_tensor.dtype))
|
input_tensor.dtype))
|
||||||
@ -233,10 +231,35 @@ def toco_convert(input_data,
|
|||||||
|
|
||||||
for output_tensor in output_tensors:
|
for output_tensor in output_tensors:
|
||||||
model.output_arrays.append(tensor_name(output_tensor))
|
model.output_arrays.append(tensor_name(output_tensor))
|
||||||
|
return model, toco
|
||||||
|
|
||||||
# TODO(aselle): Consider handling the case of allowing quantized
|
|
||||||
# inputs to be converted to float (via the toco.inference_input_type field).
|
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
|
||||||
data = toco_convert_protos(model.SerializeToString(),
|
""""Convert a model using TOCO.
|
||||||
toco.SerializeToString(),
|
|
||||||
|
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
|
||||||
|
Conversion can be customized by providing arguments that are forwarded to
|
||||||
|
`build_toco_convert_protos` (see documentation for details).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_data: Input data (i.e. often `sess.graph_def`),
|
||||||
|
input_tensors: List of input tensors. Type and shape are computed using
|
||||||
|
`foo.get_shape()` and `foo.dtype`.
|
||||||
|
output_tensors: List of output tensors (only .name is used from this).
|
||||||
|
*args: See `build_toco_convert_protos`,
|
||||||
|
**kwargs: See `build_toco_convert_protos`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The converted data. For example if TFLite was the destination, then
|
||||||
|
this will be a tflite flatbuffer in a bytes array.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Defined in `build_toco_convert_protos`.
|
||||||
|
"""
|
||||||
|
model_flags, toco_flags = build_toco_convert_protos(input_tensors,
|
||||||
|
output_tensors,
|
||||||
|
*args, **kwargs)
|
||||||
|
data = toco_convert_protos(model_flags.SerializeToString(),
|
||||||
|
toco_flags.SerializeToString(),
|
||||||
input_data.SerializeToString())
|
input_data.SerializeToString())
|
||||||
return data
|
return data
|
||||||
|
@ -25,6 +25,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
|
|||||||
|
|
||||||
@@FLOAT
|
@@FLOAT
|
||||||
@@QUANTIZED_UINT8
|
@@QUANTIZED_UINT8
|
||||||
|
@@STRING
|
||||||
@@TFLITE
|
@@TFLITE
|
||||||
@@GRAPHVIZ_DOT
|
@@GRAPHVIZ_DOT
|
||||||
|
|
||||||
@ -38,6 +39,7 @@ from six import PY3
|
|||||||
from google.protobuf import text_format as _text_format
|
from google.protobuf import text_format as _text_format
|
||||||
from google.protobuf.message import DecodeError
|
from google.protobuf.message import DecodeError
|
||||||
from tensorflow.contrib.lite.python import lite_constants as constants
|
from tensorflow.contrib.lite.python import lite_constants as constants
|
||||||
|
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
|
||||||
from tensorflow.contrib.lite.python.convert import tensor_name
|
from tensorflow.contrib.lite.python.convert import tensor_name
|
||||||
from tensorflow.contrib.lite.python.convert import toco_convert
|
from tensorflow.contrib.lite.python.convert import toco_convert
|
||||||
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
|
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
|
||||||
@ -65,10 +67,10 @@ class TocoConverter(object):
|
|||||||
Attributes:
|
Attributes:
|
||||||
|
|
||||||
inference_type: Target data type of arrays in the output file. Currently
|
inference_type: Target data type of arrays in the output file. Currently
|
||||||
must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
|
must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default FLOAT)
|
||||||
inference_input_type: Target data type of input arrays. Allows for a
|
inference_input_type: Target data type of input arrays. Allows for a
|
||||||
different type for input arrays in the case of quantization. Currently
|
different type for input arrays in the case of quantization. Currently
|
||||||
must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
|
must be `{FLOAT, QUANTIZED_UINT8, STRING}`. (default `inference_type`)
|
||||||
output_format: Output file format. Currently must be `{TFLITE,
|
output_format: Output file format. Currently must be `{TFLITE,
|
||||||
GRAPHVIZ_DOT}`. (default TFLITE)
|
GRAPHVIZ_DOT}`. (default TFLITE)
|
||||||
quantized_input_stats: Dict of strings representing input tensor names
|
quantized_input_stats: Dict of strings representing input tensor names
|
||||||
|
@ -116,7 +116,8 @@ def _convert_model(flags):
|
|||||||
"tensors in order to map between names and "
|
"tensors in order to map between names and "
|
||||||
"values.".format(",".join(input_arrays)))
|
"values.".format(",".join(input_arrays)))
|
||||||
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
|
converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
|
||||||
if flags.default_ranges_min and flags.default_ranges_max:
|
if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
|
||||||
|
not None):
|
||||||
converter.default_ranges_stats = (flags.default_ranges_min,
|
converter.default_ranges_stats = (flags.default_ranges_min,
|
||||||
flags.default_ranges_max)
|
flags.default_ranges_max)
|
||||||
|
|
||||||
@ -195,7 +196,7 @@ def _check_flags(flags, unparsed):
|
|||||||
raise ValueError("--std_dev_values, --mean_values must have the same "
|
raise ValueError("--std_dev_values, --mean_values must have the same "
|
||||||
"number of items")
|
"number of items")
|
||||||
|
|
||||||
if bool(flags.default_ranges_min) != bool(flags.default_ranges_max):
|
if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
|
||||||
raise ValueError("--default_ranges_min and --default_ranges_max must be "
|
raise ValueError("--default_ranges_min and --default_ranges_max must be "
|
||||||
"used together")
|
"used together")
|
||||||
|
|
||||||
@ -233,12 +234,12 @@ def run_main(_):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--inference_type",
|
"--inference_type",
|
||||||
type=str.upper,
|
type=str.upper,
|
||||||
choices=["FLOAT", "QUANTIZED_UINT8"],
|
choices=["FLOAT", "QUANTIZED_UINT8", "STRING"],
|
||||||
help="Target data type of arrays in the output file.")
|
help="Target data type of arrays in the output file.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--inference_input_type",
|
"--inference_input_type",
|
||||||
type=str.upper,
|
type=str.upper,
|
||||||
choices=["FLOAT", "QUANTIZED_UINT8"],
|
choices=["FLOAT", "QUANTIZED_UINT8", "STRING"],
|
||||||
help=("Target data type of input arrays. Allows for a different type for "
|
help=("Target data type of input arrays. Allows for a different type for "
|
||||||
"input arrays in the case of quantization."))
|
"input arrays in the case of quantization."))
|
||||||
|
|
||||||
|
@ -39,7 +39,7 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_
|
#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_
|
||||||
|
|
||||||
// DO NOT EDIT MANUALLY: This file is automatically generated by
|
// DO NOT EDIT MANUALLY: This file is automatically generated by
|
||||||
// `schema_builtin_ops_header_generator.py`.
|
// `schema/builtin_ops_header/generator.cc`.
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
@ -362,6 +362,8 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
|
|||||||
changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
|
changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case OperatorType::kResizeBilinear:
|
||||||
|
case OperatorType::kSlice:
|
||||||
case OperatorType::kStridedSlice:
|
case OperatorType::kStridedSlice:
|
||||||
case OperatorType::kSqueeze:
|
case OperatorType::kSqueeze:
|
||||||
case OperatorType::kTensorFlowReshape:
|
case OperatorType::kTensorFlowReshape:
|
||||||
|
@ -45,12 +45,14 @@ bool SupportsQuantization(const Operator& op) {
|
|||||||
type == OperatorType::kTensorFlowMinimum ||
|
type == OperatorType::kTensorFlowMinimum ||
|
||||||
type == OperatorType::kTensorFlowMaximum ||
|
type == OperatorType::kTensorFlowMaximum ||
|
||||||
type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
|
type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
|
||||||
type == OperatorType::kLogSoftmax ||
|
type == OperatorType::kLogSoftmax || type == OperatorType::kSlice ||
|
||||||
|
type == OperatorType::kResizeBilinear ||
|
||||||
type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub ||
|
type == OperatorType::kTensorFlowSplit || type == OperatorType::kSub ||
|
||||||
type == OperatorType::kSqueeze || type == OperatorType::kPad ||
|
type == OperatorType::kSqueeze || type == OperatorType::kPad ||
|
||||||
type == OperatorType::kPadV2 ||
|
type == OperatorType::kPadV2 ||
|
||||||
type == OperatorType::kTensorFlowReshape ||
|
type == OperatorType::kTensorFlowReshape ||
|
||||||
type == OperatorType::kTanh || type == OperatorType::kMul ||
|
type == OperatorType::kTanh || type == OperatorType::kMul ||
|
||||||
|
type == OperatorType::kSpaceToBatchND ||
|
||||||
type == OperatorType::kSpaceToDepth ||
|
type == OperatorType::kSpaceToDepth ||
|
||||||
type == OperatorType::kStridedSlice ||
|
type == OperatorType::kStridedSlice ||
|
||||||
type == OperatorType::kDepthToSpace ||
|
type == OperatorType::kDepthToSpace ||
|
||||||
|
@ -920,7 +920,7 @@ void CheckEachArray(const Model& model) {
|
|||||||
CHECK(array->buffer->type == array->data_type);
|
CHECK(array->buffer->type == array->data_type);
|
||||||
// The presence of a fixed buffer should imply the presence of a fixed
|
// The presence of a fixed buffer should imply the presence of a fixed
|
||||||
// shape.
|
// shape.
|
||||||
CHECK(array->has_shape());
|
CHECK(array->has_shape()) << "Invalid array: " << array_entry.first;
|
||||||
// Constant buffer should has a valid shape.
|
// Constant buffer should has a valid shape.
|
||||||
for (int d : array->shape().dims()) {
|
for (int d : array->shape().dims()) {
|
||||||
CHECK_GE(d, 1);
|
CHECK_GE(d, 1);
|
||||||
|
@ -8,7 +8,7 @@ load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite"
|
|||||||
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
|
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
|
||||||
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
|
load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts")
|
||||||
|
|
||||||
common_copts = ["-Wall"]
|
common_copts = ["-Wall"] + tflite_copts()
|
||||||
|
|
||||||
cc_binary(
|
cc_binary(
|
||||||
name = "benchmark_model",
|
name = "benchmark_model",
|
||||||
@ -16,14 +16,11 @@ cc_binary(
|
|||||||
"benchmark_main.cc",
|
"benchmark_main.cc",
|
||||||
"logging.h",
|
"logging.h",
|
||||||
],
|
],
|
||||||
copts = tflite_copts() + common_copts,
|
copts = common_copts,
|
||||||
linkopts = select({
|
linkopts = tflite_linkopts() + select({
|
||||||
"//tensorflow:android": [
|
"//tensorflow:android": [
|
||||||
"-pie",
|
"-pie", # Android 5.0 and later supports only PIE
|
||||||
"-landroid",
|
"-lm", # some builtin ops, e.g., tanh, need -lm
|
||||||
"-lm",
|
|
||||||
"-z defs",
|
|
||||||
"-Wl,--exclude-libs,ALL", # Exclude syms in all libs from auto export
|
|
||||||
],
|
],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
}),
|
}),
|
||||||
|
@ -53,7 +53,7 @@ tf_cc_binary(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
"//tensorflow/core/platform/cloud:gcs_file_system",
|
"//tensorflow/core/platform/cloud:gcs_file_system",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ cc_library(
|
|||||||
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
|
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
@ -69,7 +69,7 @@ cc_library(
|
|||||||
hdrs = ["grpc_verbs_service_impl.h"],
|
hdrs = ["grpc_verbs_service_impl.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":verbs_service_proto_cc",
|
":verbs_service_proto_cc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -879,6 +879,7 @@ cc_library(
|
|||||||
hdrs = [
|
hdrs = [
|
||||||
"util/stats_calculator.h",
|
"util/stats_calculator.h",
|
||||||
],
|
],
|
||||||
|
copts = tf_copts(),
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
|
10
tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt
Normal file
10
tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "BesselI0e"
|
||||||
|
summary: "Computes the Bessel i0e function of `x` element-wise."
|
||||||
|
description: <<END
|
||||||
|
Exponentially scaled modified Bessel function of order 0 defined as
|
||||||
|
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
|
||||||
|
|
||||||
|
This function is faster and numerically stabler than `bessel_i0(x)`.
|
||||||
|
END
|
||||||
|
}
|
10
tensorflow/core/api_def/base_api/api_def_BesselI1e.pbtxt
Normal file
10
tensorflow/core/api_def/base_api/api_def_BesselI1e.pbtxt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "BesselI1e"
|
||||||
|
summary: "Computes the Bessel i1e function of `x` element-wise."
|
||||||
|
description: <<END
|
||||||
|
Exponentially scaled modified Bessel function of order 0 defined as
|
||||||
|
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
|
||||||
|
|
||||||
|
This function is faster and numerically stabler than `bessel_i1(x)`.
|
||||||
|
END
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "BesselI0e"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -0,0 +1,4 @@
|
|||||||
|
op {
|
||||||
|
graph_op_name: "BesselI1e"
|
||||||
|
visibility: HIDDEN
|
||||||
|
}
|
@ -66,29 +66,27 @@ int StepStatsDeviceIndex(StepStats* step_stats, EagerContext* ctx,
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
|
// This function expects *handle to point to an existing tensor handle. The
|
||||||
EagerOperation* op, const OpKernel* kernel,
|
// function will (maybe) update the *handle to be pointed to the newly copied
|
||||||
RunMetadata* run_metadata) {
|
// tensor handle.
|
||||||
Device* host_device = ctx->HostCPU();
|
//
|
||||||
const MemoryTypeVector& memtypes = kernel->input_memory_types();
|
// The passed in *handle will be Unreffed if it is replaced.
|
||||||
if (memtypes.size() != op->Inputs().size()) {
|
Status MaybeCopyInputToExpectedDevice(EagerOperation* op, int i,
|
||||||
return errors::InvalidArgument("expected ", memtypes.size(),
|
const Device* expected_device,
|
||||||
" inputs, got ", op->Inputs().size());
|
RunMetadata* run_metadata,
|
||||||
}
|
TensorHandle** handle) {
|
||||||
for (int i = 0; i < op->Inputs().size(); ++i) {
|
EagerContext* ctx = op->EagerContext();
|
||||||
const Device* expected_device =
|
|
||||||
memtypes[i] == HOST_MEMORY ? host_device : op_device;
|
|
||||||
TensorHandle* handle = op->Inputs()[i];
|
|
||||||
Device* handle_device = nullptr;
|
Device* handle_device = nullptr;
|
||||||
TF_RETURN_IF_ERROR(handle->Device(&handle_device));
|
TF_RETURN_IF_ERROR((*handle)->Device(&handle_device));
|
||||||
const Device* actual_device =
|
const Device* actual_device =
|
||||||
handle_device == nullptr ? host_device : handle_device;
|
handle_device == nullptr ? ctx->HostCPU() : handle_device;
|
||||||
|
|
||||||
if (expected_device != actual_device) {
|
if (expected_device != actual_device) {
|
||||||
switch (ctx->GetDevicePlacementPolicy()) {
|
switch (ctx->GetDevicePlacementPolicy()) {
|
||||||
case DEVICE_PLACEMENT_SILENT_FOR_INT32:
|
case DEVICE_PLACEMENT_SILENT_FOR_INT32:
|
||||||
// TODO(xpan): See if we could bubble python related error up
|
// TODO(xpan): See if we could bubble python related error up
|
||||||
// to python level.
|
// to python level.
|
||||||
if (handle->dtype == DT_INT32) {
|
if ((*handle)->dtype == DT_INT32) {
|
||||||
// Note: enabling silent copies of int32 tensors to match behavior
|
// Note: enabling silent copies of int32 tensors to match behavior
|
||||||
// of graph mode.
|
// of graph mode.
|
||||||
break;
|
break;
|
||||||
@ -101,7 +99,7 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
|
|||||||
op->Name(), " as input #", i, " was expected to be on ",
|
op->Name(), " as input #", i, " was expected to be on ",
|
||||||
expected_device->name(), " but is actually on ",
|
expected_device->name(), " but is actually on ",
|
||||||
actual_device->name(), " (operation running on ",
|
actual_device->name(), " (operation running on ",
|
||||||
op_device->name(), ")",
|
op->Device()->name(), ")",
|
||||||
" Tensors can be copied explicitly using .gpu() or .cpu() "
|
" Tensors can be copied explicitly using .gpu() or .cpu() "
|
||||||
"methods,"
|
"methods,"
|
||||||
" or transparently copied by using tf.enable_eager_execution("
|
" or transparently copied by using tf.enable_eager_execution("
|
||||||
@ -112,7 +110,7 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
|
|||||||
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
|
LOG(WARNING) << "before computing " << op->Name() << " input #" << i
|
||||||
<< " was expected to be on " << expected_device->name()
|
<< " was expected to be on " << expected_device->name()
|
||||||
<< " but is actually on " << actual_device->name()
|
<< " but is actually on " << actual_device->name()
|
||||||
<< " (operation running on " << op_device->name()
|
<< " (operation running on " << op->Device()->name()
|
||||||
<< "). This triggers a copy which can be a performance "
|
<< "). This triggers a copy which can be a performance "
|
||||||
"bottleneck.";
|
"bottleneck.";
|
||||||
break;
|
break;
|
||||||
@ -122,9 +120,9 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
|
|||||||
// We are only here if the policy is warn or silent copies, so we should
|
// We are only here if the policy is warn or silent copies, so we should
|
||||||
// trigger a copy.
|
// trigger a copy.
|
||||||
auto pre_time = Env::Default()->NowMicros();
|
auto pre_time = Env::Default()->NowMicros();
|
||||||
TensorHandle* copied_tensor = nullptr;
|
TensorHandle* result_handle;
|
||||||
Status status = EagerCopyToDevice(
|
Status status = EagerCopyToDevice(
|
||||||
handle, ctx, expected_device->name().c_str(), &copied_tensor);
|
*handle, ctx, expected_device->name().c_str(), &result_handle);
|
||||||
if (run_metadata != nullptr) {
|
if (run_metadata != nullptr) {
|
||||||
auto* step_stats = run_metadata->mutable_step_stats();
|
auto* step_stats = run_metadata->mutable_step_stats();
|
||||||
MaybeInitializeStepStats(step_stats, ctx);
|
MaybeInitializeStepStats(step_stats, ctx);
|
||||||
@ -134,20 +132,37 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
|
|||||||
auto* node_stats = dev_stats->add_node_stats();
|
auto* node_stats = dev_stats->add_node_stats();
|
||||||
node_stats->set_node_name("_Send");
|
node_stats->set_node_name("_Send");
|
||||||
node_stats->set_all_start_micros(pre_time);
|
node_stats->set_all_start_micros(pre_time);
|
||||||
node_stats->set_op_end_rel_micros(Env::Default()->NowMicros() -
|
node_stats->set_op_end_rel_micros(Env::Default()->NowMicros() - pre_time);
|
||||||
pre_time);
|
|
||||||
}
|
}
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
if (copied_tensor != nullptr) copied_tensor->Unref();
|
if (result_handle != nullptr) result_handle->Unref();
|
||||||
return errors::Internal("Failed copying input tensor from ",
|
return errors::Internal("Failed copying input tensor from ",
|
||||||
actual_device->name(), " to ",
|
actual_device->name(), " to ",
|
||||||
expected_device->name(), " in order to run ",
|
expected_device->name(), " in order to run ",
|
||||||
op->Name(), ": ", status.error_message());
|
op->Name(), ": ", status.error_message());
|
||||||
}
|
}
|
||||||
handle->Unref();
|
|
||||||
handle = copied_tensor;
|
(*handle)->Unref();
|
||||||
(*op->MutableInputs())[i] = copied_tensor;
|
*handle = result_handle;
|
||||||
}
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
|
||||||
|
EagerOperation* op, const OpKernel* kernel,
|
||||||
|
RunMetadata* run_metadata) {
|
||||||
|
Device* host_device = ctx->HostCPU();
|
||||||
|
const MemoryTypeVector& memtypes = kernel->input_memory_types();
|
||||||
|
if (memtypes.size() != op->Inputs().size()) {
|
||||||
|
return errors::InvalidArgument("expected ", memtypes.size(),
|
||||||
|
" inputs, got ", op->Inputs().size());
|
||||||
|
}
|
||||||
|
for (int i = 0; i < op->Inputs().size(); ++i) {
|
||||||
|
const Device* expected_device =
|
||||||
|
memtypes[i] == HOST_MEMORY ? host_device : op_device;
|
||||||
|
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
||||||
|
op, i, expected_device, run_metadata, &((*op->MutableInputs())[i])));
|
||||||
|
tensorflow::TensorHandle* handle = op->Inputs()[i];
|
||||||
if (handle->dtype != kernel->input_type(i)) {
|
if (handle->dtype != kernel->input_type(i)) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"cannot compute ", op->Name(), " as input #", i,
|
"cannot compute ", op->Name(), " as input #", i,
|
||||||
@ -192,8 +207,8 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
|
|||||||
// Resource4> as the input params to the synthesized function.
|
// Resource4> as the input params to the synthesized function.
|
||||||
//
|
//
|
||||||
// It populates `const_input_types`, `arg_input_types` and
|
// It populates `const_input_types`, `arg_input_types` and
|
||||||
// `op_input_to_func_input` based on the reordering results, that the caller can
|
// `op_input_to_func_input` based on the reordering results, that the caller
|
||||||
// use them to build an XlaLaunch. On error, it returns NULL, and sets
|
// can use them to build an XlaLaunch. On error, it returns NULL, and sets
|
||||||
// `status` accordingly.
|
// `status` accordingly.
|
||||||
const FunctionDef* OpToFunction(TFE_Op* op,
|
const FunctionDef* OpToFunction(TFE_Op* op,
|
||||||
std::vector<TF_DataType>* const_input_types,
|
std::vector<TF_DataType>* const_input_types,
|
||||||
@ -221,8 +236,8 @@ const FunctionDef* OpToFunction(TFE_Op* op,
|
|||||||
const std::unordered_set<string> const_inputs(
|
const std::unordered_set<string> const_inputs(
|
||||||
*XlaOpRegistry::CompileTimeConstantInputs(op->operation.Name()));
|
*XlaOpRegistry::CompileTimeConstantInputs(op->operation.Name()));
|
||||||
|
|
||||||
// First add place holders for the input args, so that we can refer to them by
|
// First add place holders for the input args, so that we can refer to them
|
||||||
// position in the next loop. Also tally up the resource inputs.
|
// by position in the next loop. Also tally up the resource inputs.
|
||||||
int num_resource_inputs = 0;
|
int num_resource_inputs = 0;
|
||||||
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
for (int i = 0; i < op_def.input_arg_size(); ++i) {
|
||||||
if (op_def.input_arg(i).type() == DT_RESOURCE) {
|
if (op_def.input_arg(i).type() == DT_RESOURCE) {
|
||||||
@ -336,8 +351,9 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
|
|||||||
&op_input_to_func_input, status);
|
&op_input_to_func_input, status);
|
||||||
if (!status.ok()) return nullptr;
|
if (!status.ok()) return nullptr;
|
||||||
} else {
|
} else {
|
||||||
// TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
|
// TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work
|
||||||
// functions, so we need to find another way to handle constant inputs.
|
// for functions, so we need to find another way to handle constant
|
||||||
|
// inputs.
|
||||||
for (int i = const_input_types.size();
|
for (int i = const_input_types.size();
|
||||||
i < fdef->signature().input_arg_size(); ++i) {
|
i < fdef->signature().input_arg_size(); ++i) {
|
||||||
VLOG(1) << "Adding Targs from input arg " << i;
|
VLOG(1) << "Adding Targs from input arg " << i;
|
||||||
@ -348,8 +364,9 @@ std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
|
|||||||
DCHECK(fdef != nullptr);
|
DCHECK(fdef != nullptr);
|
||||||
|
|
||||||
// Copy inputs and their devices.
|
// Copy inputs and their devices.
|
||||||
// Since input param reordering may have occurred between `op` and `launch_op`
|
// Since input param reordering may have occurred between `op` and
|
||||||
// via `op_input_to_func_input`, adjust the actual inputs accordingly.
|
// `launch_op` via `op_input_to_func_input`, adjust the actual inputs
|
||||||
|
// accordingly.
|
||||||
*launch_op->operation.MutableInputs() = op->operation.Inputs();
|
*launch_op->operation.MutableInputs() = op->operation.Inputs();
|
||||||
for (TensorHandle* h : launch_op->operation.Inputs()) {
|
for (TensorHandle* h : launch_op->operation.Inputs()) {
|
||||||
h->Ref();
|
h->Ref();
|
||||||
@ -545,24 +562,24 @@ Status EagerLocalExecute(EagerOperation* op,
|
|||||||
Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
|
Status EagerRemoteExecute(EagerOperation* op, eager::EagerClient* eager_client,
|
||||||
uint64 context_id, TensorHandle** retvals,
|
uint64 context_id, TensorHandle** retvals,
|
||||||
int* num_retvals) {
|
int* num_retvals) {
|
||||||
// All tensors must be on the same device.
|
|
||||||
// TODO(nareshmodi): handle silent copies
|
|
||||||
eager::EnqueueRequest request;
|
eager::EnqueueRequest request;
|
||||||
eager::EnqueueResponse response;
|
eager::EnqueueResponse response;
|
||||||
|
|
||||||
auto* remote_op = request.add_queue()->mutable_operation();
|
auto* remote_op = request.add_queue()->mutable_operation();
|
||||||
|
|
||||||
for (auto* input : op->Inputs()) {
|
for (int i = 0; i < op->Inputs().size(); i++) {
|
||||||
tensorflow::Device* input_device;
|
tensorflow::Device* input_device;
|
||||||
TF_RETURN_IF_ERROR(input->Device(&input_device));
|
TF_RETURN_IF_ERROR(op->Inputs()[i]->Device(&input_device));
|
||||||
if (op->Device() != input_device) {
|
if (op->Device() != input_device) {
|
||||||
return tensorflow::errors::InvalidArgument(
|
// TODO(b/110044833): It's possible the same tensor gets copied to the
|
||||||
"Ops and inputs are not on the same device. Use "
|
// remote device repeatedly.
|
||||||
"TFE_TensorHandleCopyToDevice to get ops on the same "
|
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
|
||||||
"device. Expected device: ",
|
op, i, op->Device(), /* run_metadata= */ nullptr,
|
||||||
op->Device()->name(), ", Actual device: ", input_device->name());
|
&(*op->MutableInputs())[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::TensorHandle* input = op->Inputs()[i];
|
||||||
|
|
||||||
tensorflow::uint64 op_id;
|
tensorflow::uint64 op_id;
|
||||||
int32 output_num;
|
int32 output_num;
|
||||||
TF_RETURN_IF_ERROR(input->RemoteAddress(&op_id, &output_num));
|
TF_RETURN_IF_ERROR(input->RemoteAddress(&op_id, &output_num));
|
||||||
|
@ -42,7 +42,7 @@ load(
|
|||||||
# Check that tensorflow/core:tensorflow does not depend on grpc.
|
# Check that tensorflow/core:tensorflow does not depend on grpc.
|
||||||
check_deps(
|
check_deps(
|
||||||
name = "core_tensorflow_check_deps",
|
name = "core_tensorflow_check_deps",
|
||||||
disallowed_deps = ["@grpc//:grpc++_unsecure"],
|
disallowed_deps = ["@grpc//:grpc++"],
|
||||||
deps = ["//tensorflow/core:tensorflow"],
|
deps = ["//tensorflow/core:tensorflow"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -150,7 +150,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:proto_text",
|
"//tensorflow/core:proto_text",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
@ -170,7 +170,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
@ -649,7 +649,7 @@ tf_cuda_cc_test(
|
|||||||
"//tensorflow/core/kernels:dense_update_ops",
|
"//tensorflow/core/kernels:dense_update_ops",
|
||||||
"//tensorflow/core/kernels:identity_op",
|
"//tensorflow/core/kernels:identity_op",
|
||||||
"//tensorflow/core/kernels:variable_ops",
|
"//tensorflow/core/kernels:variable_ops",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -682,7 +682,7 @@ tf_cuda_cc_test(
|
|||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -65,8 +65,8 @@ cc_library(
|
|||||||
"//tensorflow/core/distributed_runtime:worker_env",
|
"//tensorflow/core/distributed_runtime:worker_env",
|
||||||
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
|
"//tensorflow/core/distributed_runtime/eager:remote_tensor_handle",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
|
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc",
|
||||||
"@grpc//:grpc_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -41,8 +41,8 @@ cc_library(
|
|||||||
srcs = ["grpc_util.cc"],
|
srcs = ["grpc_util.cc"],
|
||||||
hdrs = ["grpc_util.h"],
|
hdrs = ["grpc_util.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"@grpc//:grpc_unsecure",
|
"@grpc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
# Required to be able to overload TensorResponse parsing.
|
# Required to be able to overload TensorResponse parsing.
|
||||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||||
@ -56,7 +56,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":grpc_util",
|
":grpc_util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -70,7 +70,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core/distributed_runtime:call_options",
|
"//tensorflow/core/distributed_runtime:call_options",
|
||||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -90,7 +90,7 @@ cc_library(
|
|||||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||||
"//tensorflow/core/distributed_runtime:worker_cache_logger",
|
"//tensorflow/core/distributed_runtime:worker_cache_logger",
|
||||||
"//tensorflow/core/distributed_runtime:worker_interface",
|
"//tensorflow/core/distributed_runtime:worker_interface",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -103,7 +103,7 @@ cc_library(
|
|||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -118,7 +118,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -129,7 +129,7 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -180,7 +180,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/core/distributed_runtime:worker_cache",
|
"//tensorflow/core/distributed_runtime:worker_cache",
|
||||||
"//tensorflow/core/distributed_runtime:worker_env",
|
"//tensorflow/core/distributed_runtime:worker_env",
|
||||||
"//tensorflow/core/distributed_runtime:worker_session",
|
"//tensorflow/core/distributed_runtime:worker_session",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -192,7 +192,7 @@ cc_library(
|
|||||||
":grpc_util",
|
":grpc_util",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
"//tensorflow/core/distributed_runtime:tensor_coding",
|
"//tensorflow/core/distributed_runtime:tensor_coding",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -225,7 +225,7 @@ cc_library(
|
|||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"//tensorflow/core:master_proto_cc",
|
"//tensorflow/core:master_proto_cc",
|
||||||
"//tensorflow/core/distributed_runtime:master",
|
"//tensorflow/core/distributed_runtime:master",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
@ -236,7 +236,7 @@ cc_library(
|
|||||||
hdrs = ["grpc_master_service_impl.h"],
|
hdrs = ["grpc_master_service_impl.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:master_proto_cc",
|
"//tensorflow/core:master_proto_cc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -285,8 +285,8 @@ cc_library(
|
|||||||
"//tensorflow/core/distributed_runtime:server_lib",
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||||
"//tensorflow/core/distributed_runtime:worker_env",
|
"//tensorflow/core/distributed_runtime:worker_env",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc",
|
||||||
"@grpc//:grpc_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
@ -313,7 +313,7 @@ tf_cc_binary(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/distributed_runtime:server_lib",
|
"//tensorflow/core/distributed_runtime:server_lib",
|
||||||
"//tensorflow/core/kernels:data_flow",
|
"//tensorflow/core/kernels:data_flow",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -338,7 +338,7 @@ tf_cc_binary(
|
|||||||
"//tensorflow/core/kernels:matmul_op",
|
"//tensorflow/core/kernels:matmul_op",
|
||||||
"//tensorflow/core/kernels:reduction_ops",
|
"//tensorflow/core/kernels:reduction_ops",
|
||||||
"//tensorflow/core/kernels:variable_ops",
|
"//tensorflow/core/kernels:variable_ops",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -432,7 +432,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:testlib",
|
"//tensorflow/core:testlib",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -445,8 +445,8 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//tensorflow/core:worker_proto_cc",
|
"//tensorflow/core:worker_proto_cc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc",
|
||||||
"@grpc//:grpc_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ cc_library(
|
|||||||
hdrs = ["grpc_eager_service.h"],
|
hdrs = ["grpc_eager_service.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/core:eager_service_proto_cc",
|
"//tensorflow/core:eager_service_proto_cc",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -29,7 +29,7 @@ cc_library(
|
|||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_state",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_state",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service",
|
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_service",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ cc_library(
|
|||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
||||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
|
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
|
||||||
"@grpc//:grpc++_unsecure",
|
"@grpc//:grpc++",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
29
tensorflow/core/kernels/cwise_op_bessel.cc
Normal file
29
tensorflow/core/kernels/cwise_op_bessel.cc
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/cwise_ops_common.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
REGISTER3(UnaryOp, CPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
|
||||||
|
double);
|
||||||
|
REGISTER3(UnaryOp, CPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
|
||||||
|
double);
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
REGISTER3(UnaryOp, GPU, "BesselI0e", functor::bessel_i0e, Eigen::half, float,
|
||||||
|
double);
|
||||||
|
REGISTER3(UnaryOp, GPU, "BesselI1e", functor::bessel_i1e, Eigen::half, float,
|
||||||
|
double);
|
||||||
|
#endif
|
||||||
|
} // namespace tensorflow
|
27
tensorflow/core/kernels/cwise_op_bessel.cu.cc
Normal file
27
tensorflow/core/kernels/cwise_op_bessel.cu.cc
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
|
||||||
|
#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace functor {
|
||||||
|
DEFINE_UNARY3(bessel_i0e, Eigen::half, float, double);
|
||||||
|
DEFINE_UNARY3(bessel_i1e, Eigen::half, float, double);
|
||||||
|
} // namespace functor
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // GOOGLE_CUDA
|
@ -616,6 +616,12 @@ struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};
|
|||||||
template <typename T>
|
template <typename T>
|
||||||
struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};
|
struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct bessel_i0e : base<T, Eigen::internal::scalar_i0e_op<T>> {};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
struct bessel_i1e : base<T, Eigen::internal::scalar_i1e_op<T>> {};
|
||||||
|
|
||||||
struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
|
struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -207,12 +207,6 @@ class IteratorResource : public ResourceBase {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
std::shared_ptr<StatsAggregator> stats_aggregator() {
|
|
||||||
tf_shared_lock l(mu_);
|
|
||||||
return stats_aggregator_;
|
|
||||||
}
|
|
||||||
|
|
||||||
string DebugString() override { return "Iterator resource"; }
|
string DebugString() override { return "Iterator resource"; }
|
||||||
|
|
||||||
const DataTypeVector& output_dtypes() const { return output_dtypes_; }
|
const DataTypeVector& output_dtypes() const { return output_dtypes_; }
|
||||||
@ -231,7 +225,6 @@ class IteratorResource : public ResourceBase {
|
|||||||
FunctionLibraryRuntime* lib_ = nullptr; // not owned.
|
FunctionLibraryRuntime* lib_ = nullptr; // not owned.
|
||||||
std::shared_ptr<IteratorBase> iterator_;
|
std::shared_ptr<IteratorBase> iterator_;
|
||||||
mutex mu_;
|
mutex mu_;
|
||||||
std::shared_ptr<StatsAggregator> stats_aggregator_ GUARDED_BY(mu_);
|
|
||||||
std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
|
std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
|
||||||
const DataTypeVector output_dtypes_;
|
const DataTypeVector output_dtypes_;
|
||||||
const std::vector<PartialTensorShape> output_shapes_;
|
const std::vector<PartialTensorShape> output_shapes_;
|
||||||
@ -944,9 +937,6 @@ class IteratorGetNextOp : public AsyncOpKernel {
|
|||||||
|
|
||||||
IteratorContext::Params params;
|
IteratorContext::Params params;
|
||||||
params.env = ctx->env();
|
params.env = ctx->env();
|
||||||
params.stats_aggregator_getter = [iterator]() {
|
|
||||||
return iterator->stats_aggregator();
|
|
||||||
};
|
|
||||||
params.runner = *(ctx->runner());
|
params.runner = *(ctx->runner());
|
||||||
params.function_library = iterator->function_library();
|
params.function_library = iterator->function_library();
|
||||||
DeviceBase* device = ctx->function_library()->device();
|
DeviceBase* device = ctx->function_library()->device();
|
||||||
@ -995,9 +985,6 @@ class IteratorGetNextSyncOp : public OpKernel {
|
|||||||
|
|
||||||
IteratorContext::Params params;
|
IteratorContext::Params params;
|
||||||
params.env = ctx->env();
|
params.env = ctx->env();
|
||||||
params.stats_aggregator_getter = [iterator]() {
|
|
||||||
return iterator->stats_aggregator();
|
|
||||||
};
|
|
||||||
params.runner = *(ctx->runner());
|
params.runner = *(ctx->runner());
|
||||||
params.function_library = iterator->function_library();
|
params.function_library = iterator->function_library();
|
||||||
DeviceBase* device = ctx->function_library()->device();
|
DeviceBase* device = ctx->function_library()->device();
|
||||||
|
@ -10085,6 +10085,52 @@ op {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "BesselI0e"
|
||||||
|
input_arg {
|
||||||
|
name: "x"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "y"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_BFLOAT16
|
||||||
|
type: DT_HALF
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "BesselI1e"
|
||||||
|
input_arg {
|
||||||
|
name: "x"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "y"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_BFLOAT16
|
||||||
|
type: DT_HALF
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "Betainc"
|
name: "Betainc"
|
||||||
input_arg {
|
input_arg {
|
||||||
@ -25468,6 +25514,44 @@ op {
|
|||||||
type: "func"
|
type: "func"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "If"
|
||||||
|
input_arg {
|
||||||
|
name: "cond"
|
||||||
|
type_attr: "Tcond"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "input"
|
||||||
|
type_list_attr: "Tin"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output"
|
||||||
|
type_list_attr: "Tout"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tcond"
|
||||||
|
type: "type"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tin"
|
||||||
|
type: "list(type)"
|
||||||
|
has_minimum: true
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "Tout"
|
||||||
|
type: "list(type)"
|
||||||
|
has_minimum: true
|
||||||
|
minimum: 1
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "then_branch"
|
||||||
|
type: "func"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "else_branch"
|
||||||
|
type: "func"
|
||||||
|
}
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "Igamma"
|
name: "Igamma"
|
||||||
input_arg {
|
input_arg {
|
||||||
|
@ -239,6 +239,10 @@ REGISTER_OP("Acos").UNARY();
|
|||||||
|
|
||||||
REGISTER_OP("Atan").UNARY();
|
REGISTER_OP("Atan").UNARY();
|
||||||
|
|
||||||
|
REGISTER_OP("BesselI0e").UNARY_REAL();
|
||||||
|
|
||||||
|
REGISTER_OP("BesselI1e").UNARY_REAL();
|
||||||
|
|
||||||
#undef UNARY
|
#undef UNARY
|
||||||
#undef UNARY_REAL
|
#undef UNARY_REAL
|
||||||
#undef UNARY_COMPLEX
|
#undef UNARY_COMPLEX
|
||||||
|
@ -3860,6 +3860,52 @@ op {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "BesselI0e"
|
||||||
|
input_arg {
|
||||||
|
name: "x"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "y"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_BFLOAT16
|
||||||
|
type: DT_HALF
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
op {
|
||||||
|
name: "BesselI1e"
|
||||||
|
input_arg {
|
||||||
|
name: "x"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "y"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_BFLOAT16
|
||||||
|
type: DT_HALF
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "Betainc"
|
name: "Betainc"
|
||||||
input_arg {
|
input_arg {
|
||||||
@ -12358,7 +12404,6 @@ op {
|
|||||||
name: "Tin"
|
name: "Tin"
|
||||||
type: "list(type)"
|
type: "list(type)"
|
||||||
has_minimum: true
|
has_minimum: true
|
||||||
minimum: 1
|
|
||||||
}
|
}
|
||||||
attr {
|
attr {
|
||||||
name: "Tout"
|
name: "Tout"
|
||||||
|
@ -4210,69 +4210,6 @@ func Digamma(scope *Scope, x tf.Output) (y tf.Output) {
|
|||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shuffle dimensions of x according to a permutation.
|
|
||||||
//
|
|
||||||
// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
|
|
||||||
// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
|
|
||||||
func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "Transpose",
|
|
||||||
Input: []tf.Input{
|
|
||||||
x, perm,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MinAttr is an optional argument to Min.
|
|
||||||
type MinAttr func(optionalAttr)
|
|
||||||
|
|
||||||
// MinKeepDims sets the optional keep_dims attribute to value.
|
|
||||||
//
|
|
||||||
// value: If true, retain reduced dimensions with length 1.
|
|
||||||
// If not specified, defaults to false
|
|
||||||
func MinKeepDims(value bool) MinAttr {
|
|
||||||
return func(m optionalAttr) {
|
|
||||||
m["keep_dims"] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes the minimum of elements across dimensions of a tensor.
|
|
||||||
//
|
|
||||||
// Reduces `input` along the dimensions given in `axis`. Unless
|
|
||||||
// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
|
|
||||||
// `axis`. If `keep_dims` is true, the reduced dimensions are
|
|
||||||
// retained with length 1.
|
|
||||||
//
|
|
||||||
// Arguments:
|
|
||||||
// input: The tensor to reduce.
|
|
||||||
// axis: The dimensions to reduce. Must be in the range
|
|
||||||
// `[-rank(input), rank(input))`.
|
|
||||||
//
|
|
||||||
// Returns The reduced tensor.
|
|
||||||
func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
attrs := map[string]interface{}{}
|
|
||||||
for _, a := range optional {
|
|
||||||
a(attrs)
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "Min",
|
|
||||||
Input: []tf.Input{
|
|
||||||
input, axis,
|
|
||||||
},
|
|
||||||
Attrs: attrs,
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter.
|
// Conv2DBackpropFilterAttr is an optional argument to Conv2DBackpropFilter.
|
||||||
type Conv2DBackpropFilterAttr func(optionalAttr)
|
type Conv2DBackpropFilterAttr func(optionalAttr)
|
||||||
|
|
||||||
@ -6181,6 +6118,77 @@ func Mod(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
|
|||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Computes offsets of concat inputs within its output.
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// # 'x' is [2, 2, 7]
|
||||||
|
// # 'y' is [2, 3, 7]
|
||||||
|
// # 'z' is [2, 5, 7]
|
||||||
|
// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0]
|
||||||
|
// ```
|
||||||
|
//
|
||||||
|
// This is typically used by gradient computations for a concat operation.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
// concat_dim: The dimension along which to concatenate.
|
||||||
|
// shape: The `N` int32 vectors representing shape of tensors being concatenated.
|
||||||
|
//
|
||||||
|
// Returns The `N` int32 vectors representing the starting offset
|
||||||
|
// of input tensors within the concatenated output.
|
||||||
|
func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "ConcatOffset",
|
||||||
|
Input: []tf.Input{
|
||||||
|
concat_dim, tf.OutputList(shape),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var idx int
|
||||||
|
var err error
|
||||||
|
if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil {
|
||||||
|
scope.UpdateErr("ConcatOffset", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return offset
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the lower regularized incomplete Gamma function `Q(a, x)`.
|
||||||
|
//
|
||||||
|
// The lower regularized incomplete Gamma function is defined as:
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\)
|
||||||
|
//
|
||||||
|
// where
|
||||||
|
//
|
||||||
|
// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
|
||||||
|
//
|
||||||
|
// is the lower incomplete Gamma function.
|
||||||
|
//
|
||||||
|
// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete
|
||||||
|
// Gamma function.
|
||||||
|
func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "Igamma",
|
||||||
|
Input: []tf.Input{
|
||||||
|
a, x,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
// DepthToSpaceAttr is an optional argument to DepthToSpace.
|
// DepthToSpaceAttr is an optional argument to DepthToSpace.
|
||||||
type DepthToSpaceAttr func(optionalAttr)
|
type DepthToSpaceAttr func(optionalAttr)
|
||||||
|
|
||||||
@ -7000,6 +7008,69 @@ func BiasAddV1(scope *Scope, value tf.Output, bias tf.Output) (output tf.Output)
|
|||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shuffle dimensions of x according to a permutation.
|
||||||
|
//
|
||||||
|
// The output `y` has the same rank as `x`. The shapes of `x` and `y` satisfy:
|
||||||
|
// `y.shape[i] == x.shape[perm[i]] for i in [0, 1, ..., rank(x) - 1]`
|
||||||
|
func Transpose(scope *Scope, x tf.Output, perm tf.Output) (y tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "Transpose",
|
||||||
|
Input: []tf.Input{
|
||||||
|
x, perm,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinAttr is an optional argument to Min.
|
||||||
|
type MinAttr func(optionalAttr)
|
||||||
|
|
||||||
|
// MinKeepDims sets the optional keep_dims attribute to value.
|
||||||
|
//
|
||||||
|
// value: If true, retain reduced dimensions with length 1.
|
||||||
|
// If not specified, defaults to false
|
||||||
|
func MinKeepDims(value bool) MinAttr {
|
||||||
|
return func(m optionalAttr) {
|
||||||
|
m["keep_dims"] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes the minimum of elements across dimensions of a tensor.
|
||||||
|
//
|
||||||
|
// Reduces `input` along the dimensions given in `axis`. Unless
|
||||||
|
// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in
|
||||||
|
// `axis`. If `keep_dims` is true, the reduced dimensions are
|
||||||
|
// retained with length 1.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
// input: The tensor to reduce.
|
||||||
|
// axis: The dimensions to reduce. Must be in the range
|
||||||
|
// `[-rank(input), rank(input))`.
|
||||||
|
//
|
||||||
|
// Returns The reduced tensor.
|
||||||
|
func Min(scope *Scope, input tf.Output, axis tf.Output, optional ...MinAttr) (output tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
attrs := map[string]interface{}{}
|
||||||
|
for _, a := range optional {
|
||||||
|
a(attrs)
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "Min",
|
||||||
|
Input: []tf.Input{
|
||||||
|
input, axis,
|
||||||
|
},
|
||||||
|
Attrs: attrs,
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
// Transforms a Tensor into a serialized TensorProto proto.
|
// Transforms a Tensor into a serialized TensorProto proto.
|
||||||
//
|
//
|
||||||
// Arguments:
|
// Arguments:
|
||||||
@ -11592,60 +11663,6 @@ func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output
|
|||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResizeAreaAttr is an optional argument to ResizeArea.
|
|
||||||
type ResizeAreaAttr func(optionalAttr)
|
|
||||||
|
|
||||||
// ResizeAreaAlignCorners sets the optional align_corners attribute to value.
|
|
||||||
//
|
|
||||||
// value: If true, the centers of the 4 corner pixels of the input and output tensors are
|
|
||||||
// aligned, preserving the values at the corner pixels. Defaults to false.
|
|
||||||
// If not specified, defaults to false
|
|
||||||
func ResizeAreaAlignCorners(value bool) ResizeAreaAttr {
|
|
||||||
return func(m optionalAttr) {
|
|
||||||
m["align_corners"] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resize `images` to `size` using area interpolation.
|
|
||||||
//
|
|
||||||
// Input images can be of different types but output images are always float.
|
|
||||||
//
|
|
||||||
// The range of pixel values for the output image might be slightly different
|
|
||||||
// from the range for the input image because of limited numerical precision.
|
|
||||||
// To guarantee an output range, for example `[0.0, 1.0]`, apply
|
|
||||||
// `tf.clip_by_value` to the output.
|
|
||||||
//
|
|
||||||
// Each output pixel is computed by first transforming the pixel's footprint into
|
|
||||||
// the input tensor and then averaging the pixels that intersect the footprint. An
|
|
||||||
// input pixel's contribution to the average is weighted by the fraction of its
|
|
||||||
// area that intersects the footprint. This is the same as OpenCV's INTER_AREA.
|
|
||||||
//
|
|
||||||
// Arguments:
|
|
||||||
// images: 4-D with shape `[batch, height, width, channels]`.
|
|
||||||
// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
|
|
||||||
// new size for the images.
|
|
||||||
//
|
|
||||||
// Returns 4-D with shape
|
|
||||||
// `[batch, new_height, new_width, channels]`.
|
|
||||||
func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
attrs := map[string]interface{}{}
|
|
||||||
for _, a := range optional {
|
|
||||||
a(attrs)
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "ResizeArea",
|
|
||||||
Input: []tf.Input{
|
|
||||||
images, size,
|
|
||||||
},
|
|
||||||
Attrs: attrs,
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2D real-valued fast Fourier transform.
|
// 2D real-valued fast Fourier transform.
|
||||||
//
|
//
|
||||||
// Computes the 2-dimensional discrete Fourier transform of a real-valued signal
|
// Computes the 2-dimensional discrete Fourier transform of a real-valued signal
|
||||||
@ -13635,170 +13652,6 @@ func TopK(scope *Scope, input tf.Output, k int64, optional ...TopKAttr) (values
|
|||||||
return op.Output(0), op.Output(1)
|
return op.Output(0), op.Output(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ComplexAttr is an optional argument to Complex.
|
|
||||||
type ComplexAttr func(optionalAttr)
|
|
||||||
|
|
||||||
// ComplexTout sets the optional Tout attribute to value.
|
|
||||||
// If not specified, defaults to DT_COMPLEX64
|
|
||||||
func ComplexTout(value tf.DataType) ComplexAttr {
|
|
||||||
return func(m optionalAttr) {
|
|
||||||
m["Tout"] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Converts two real numbers to a complex number.
|
|
||||||
//
|
|
||||||
// Given a tensor `real` representing the real part of a complex number, and a
|
|
||||||
// tensor `imag` representing the imaginary part of a complex number, this
|
|
||||||
// operation returns complex numbers elementwise of the form \\(a + bj\\), where
|
|
||||||
// *a* represents the `real` part and *b* represents the `imag` part.
|
|
||||||
//
|
|
||||||
// The input tensors `real` and `imag` must have the same shape.
|
|
||||||
//
|
|
||||||
// For example:
|
|
||||||
//
|
|
||||||
// ```
|
|
||||||
// # tensor 'real' is [2.25, 3.25]
|
|
||||||
// # tensor `imag` is [4.75, 5.75]
|
|
||||||
// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
|
|
||||||
// ```
|
|
||||||
func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
attrs := map[string]interface{}{}
|
|
||||||
for _, a := range optional {
|
|
||||||
a(attrs)
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "Complex",
|
|
||||||
Input: []tf.Input{
|
|
||||||
real, imag,
|
|
||||||
},
|
|
||||||
Attrs: attrs,
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ImagAttr is an optional argument to Imag.
|
|
||||||
type ImagAttr func(optionalAttr)
|
|
||||||
|
|
||||||
// ImagTout sets the optional Tout attribute to value.
|
|
||||||
// If not specified, defaults to DT_FLOAT
|
|
||||||
func ImagTout(value tf.DataType) ImagAttr {
|
|
||||||
return func(m optionalAttr) {
|
|
||||||
m["Tout"] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the imaginary part of a complex number.
|
|
||||||
//
|
|
||||||
// Given a tensor `input` of complex numbers, this operation returns a tensor of
|
|
||||||
// type `float` that is the imaginary part of each element in `input`. All
|
|
||||||
// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
|
|
||||||
// is the real part and *b* is the imaginary part returned by this operation.
|
|
||||||
//
|
|
||||||
// For example:
|
|
||||||
//
|
|
||||||
// ```
|
|
||||||
// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
|
|
||||||
// tf.imag(input) ==> [4.75, 5.75]
|
|
||||||
// ```
|
|
||||||
func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
attrs := map[string]interface{}{}
|
|
||||||
for _, a := range optional {
|
|
||||||
a(attrs)
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "Imag",
|
|
||||||
Input: []tf.Input{
|
|
||||||
input,
|
|
||||||
},
|
|
||||||
Attrs: attrs,
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes the maximum along segments of a tensor.
|
|
||||||
//
|
|
||||||
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
|
|
||||||
// segments.
|
|
||||||
//
|
|
||||||
// Computes a tensor such that
|
|
||||||
// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
|
|
||||||
// that `segment_ids[j] == i`.
|
|
||||||
//
|
|
||||||
// If the max is empty for a given segment ID `i`, `output[i] = 0`.
|
|
||||||
//
|
|
||||||
// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
|
||||||
// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
|
|
||||||
// </div>
|
|
||||||
//
|
|
||||||
// Arguments:
|
|
||||||
//
|
|
||||||
// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
|
|
||||||
// first dimension. Values should be sorted and can be repeated.
|
|
||||||
//
|
|
||||||
// Returns Has same shape as data, except for dimension 0 which
|
|
||||||
// has size `k`, the number of segments.
|
|
||||||
func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "SegmentMax",
|
|
||||||
Input: []tf.Input{
|
|
||||||
data, segment_ids,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes hyperbolic tangent of `x` element-wise.
|
|
||||||
func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "Tanh",
|
|
||||||
Input: []tf.Input{
|
|
||||||
x,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Creates a dataset that skips `count` elements from the `input_dataset`.
|
|
||||||
//
|
|
||||||
// Arguments:
|
|
||||||
//
|
|
||||||
// count: A scalar representing the number of elements from the `input_dataset`
|
|
||||||
// that should be skipped. If count is -1, skips everything.
|
|
||||||
//
|
|
||||||
//
|
|
||||||
func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "SkipDataset",
|
|
||||||
Input: []tf.Input{
|
|
||||||
input_dataset, count,
|
|
||||||
},
|
|
||||||
Attrs: attrs,
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
|
// Compute the Hurwitz zeta function \\(\zeta(x, q)\\).
|
||||||
//
|
//
|
||||||
// The Hurwitz zeta function is defined as:
|
// The Hurwitz zeta function is defined as:
|
||||||
@ -14064,49 +13917,6 @@ func Minimum(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
|
|||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RealAttr is an optional argument to Real.
|
|
||||||
type RealAttr func(optionalAttr)
|
|
||||||
|
|
||||||
// RealTout sets the optional Tout attribute to value.
|
|
||||||
// If not specified, defaults to DT_FLOAT
|
|
||||||
func RealTout(value tf.DataType) RealAttr {
|
|
||||||
return func(m optionalAttr) {
|
|
||||||
m["Tout"] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the real part of a complex number.
|
|
||||||
//
|
|
||||||
// Given a tensor `input` of complex numbers, this operation returns a tensor of
|
|
||||||
// type `float` that is the real part of each element in `input`. All elements in
|
|
||||||
// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real
|
|
||||||
// part returned by this operation and *b* is the imaginary part.
|
|
||||||
//
|
|
||||||
// For example:
|
|
||||||
//
|
|
||||||
// ```
|
|
||||||
// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
|
|
||||||
// tf.real(input) ==> [-2.25, 3.25]
|
|
||||||
// ```
|
|
||||||
func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
attrs := map[string]interface{}{}
|
|
||||||
for _, a := range optional {
|
|
||||||
a(attrs)
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "Real",
|
|
||||||
Input: []tf.Input{
|
|
||||||
input,
|
|
||||||
},
|
|
||||||
Attrs: attrs,
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AudioSummaryAttr is an optional argument to AudioSummary.
|
// AudioSummaryAttr is an optional argument to AudioSummary.
|
||||||
type AudioSummaryAttr func(optionalAttr)
|
type AudioSummaryAttr func(optionalAttr)
|
||||||
|
|
||||||
@ -19704,6 +19514,267 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or
|
|||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ComplexAttr is an optional argument to Complex.
|
||||||
|
type ComplexAttr func(optionalAttr)
|
||||||
|
|
||||||
|
// ComplexTout sets the optional Tout attribute to value.
|
||||||
|
// If not specified, defaults to DT_COMPLEX64
|
||||||
|
func ComplexTout(value tf.DataType) ComplexAttr {
|
||||||
|
return func(m optionalAttr) {
|
||||||
|
m["Tout"] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts two real numbers to a complex number.
|
||||||
|
//
|
||||||
|
// Given a tensor `real` representing the real part of a complex number, and a
|
||||||
|
// tensor `imag` representing the imaginary part of a complex number, this
|
||||||
|
// operation returns complex numbers elementwise of the form \\(a + bj\\), where
|
||||||
|
// *a* represents the `real` part and *b* represents the `imag` part.
|
||||||
|
//
|
||||||
|
// The input tensors `real` and `imag` must have the same shape.
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// # tensor 'real' is [2.25, 3.25]
|
||||||
|
// # tensor `imag` is [4.75, 5.75]
|
||||||
|
// tf.complex(real, imag) ==> [[2.25 + 4.75j], [3.25 + 5.75j]]
|
||||||
|
// ```
|
||||||
|
func Complex(scope *Scope, real tf.Output, imag tf.Output, optional ...ComplexAttr) (out tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
attrs := map[string]interface{}{}
|
||||||
|
for _, a := range optional {
|
||||||
|
a(attrs)
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "Complex",
|
||||||
|
Input: []tf.Input{
|
||||||
|
real, imag,
|
||||||
|
},
|
||||||
|
Attrs: attrs,
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImagAttr is an optional argument to Imag.
|
||||||
|
type ImagAttr func(optionalAttr)
|
||||||
|
|
||||||
|
// ImagTout sets the optional Tout attribute to value.
|
||||||
|
// If not specified, defaults to DT_FLOAT
|
||||||
|
func ImagTout(value tf.DataType) ImagAttr {
|
||||||
|
return func(m optionalAttr) {
|
||||||
|
m["Tout"] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the imaginary part of a complex number.
|
||||||
|
//
|
||||||
|
// Given a tensor `input` of complex numbers, this operation returns a tensor of
|
||||||
|
// type `float` that is the imaginary part of each element in `input`. All
|
||||||
|
// elements in `input` must be complex numbers of the form \\(a + bj\\), where *a*
|
||||||
|
// is the real part and *b* is the imaginary part returned by this operation.
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
|
||||||
|
// tf.imag(input) ==> [4.75, 5.75]
|
||||||
|
// ```
|
||||||
|
func Imag(scope *Scope, input tf.Output, optional ...ImagAttr) (output tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
attrs := map[string]interface{}{}
|
||||||
|
for _, a := range optional {
|
||||||
|
a(attrs)
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "Imag",
|
||||||
|
Input: []tf.Input{
|
||||||
|
input,
|
||||||
|
},
|
||||||
|
Attrs: attrs,
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes the maximum along segments of a tensor.
|
||||||
|
//
|
||||||
|
// Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
|
||||||
|
// segments.
|
||||||
|
//
|
||||||
|
// Computes a tensor such that
|
||||||
|
// \\(output_i = \max_j(data_j)\\) where `max` is over `j` such
|
||||||
|
// that `segment_ids[j] == i`.
|
||||||
|
//
|
||||||
|
// If the max is empty for a given segment ID `i`, `output[i] = 0`.
|
||||||
|
//
|
||||||
|
// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||||
|
// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
|
||||||
|
// </div>
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
//
|
||||||
|
// segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
|
||||||
|
// first dimension. Values should be sorted and can be repeated.
|
||||||
|
//
|
||||||
|
// Returns Has same shape as data, except for dimension 0 which
|
||||||
|
// has size `k`, the number of segments.
|
||||||
|
func SegmentMax(scope *Scope, data tf.Output, segment_ids tf.Output) (output tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "SegmentMax",
|
||||||
|
Input: []tf.Input{
|
||||||
|
data, segment_ids,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Computes hyperbolic tangent of `x` element-wise.
|
||||||
|
func Tanh(scope *Scope, x tf.Output) (y tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "Tanh",
|
||||||
|
Input: []tf.Input{
|
||||||
|
x,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a dataset that skips `count` elements from the `input_dataset`.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
//
|
||||||
|
// count: A scalar representing the number of elements from the `input_dataset`
|
||||||
|
// that should be skipped. If count is -1, skips everything.
|
||||||
|
//
|
||||||
|
//
|
||||||
|
func SkipDataset(scope *Scope, input_dataset tf.Output, count tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "SkipDataset",
|
||||||
|
Input: []tf.Input{
|
||||||
|
input_dataset, count,
|
||||||
|
},
|
||||||
|
Attrs: attrs,
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RealAttr is an optional argument to Real.
|
||||||
|
type RealAttr func(optionalAttr)
|
||||||
|
|
||||||
|
// RealTout sets the optional Tout attribute to value.
|
||||||
|
// If not specified, defaults to DT_FLOAT
|
||||||
|
func RealTout(value tf.DataType) RealAttr {
|
||||||
|
return func(m optionalAttr) {
|
||||||
|
m["Tout"] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the real part of a complex number.
|
||||||
|
//
|
||||||
|
// Given a tensor `input` of complex numbers, this operation returns a tensor of
|
||||||
|
// type `float` that is the real part of each element in `input`. All elements in
|
||||||
|
// `input` must be complex numbers of the form \\(a + bj\\), where *a* is the real
|
||||||
|
// part returned by this operation and *b* is the imaginary part.
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// ```
|
||||||
|
// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
|
||||||
|
// tf.real(input) ==> [-2.25, 3.25]
|
||||||
|
// ```
|
||||||
|
func Real(scope *Scope, input tf.Output, optional ...RealAttr) (output tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
attrs := map[string]interface{}{}
|
||||||
|
for _, a := range optional {
|
||||||
|
a(attrs)
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "Real",
|
||||||
|
Input: []tf.Input{
|
||||||
|
input,
|
||||||
|
},
|
||||||
|
Attrs: attrs,
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResizeAreaAttr is an optional argument to ResizeArea.
|
||||||
|
type ResizeAreaAttr func(optionalAttr)
|
||||||
|
|
||||||
|
// ResizeAreaAlignCorners sets the optional align_corners attribute to value.
|
||||||
|
//
|
||||||
|
// value: If true, the centers of the 4 corner pixels of the input and output tensors are
|
||||||
|
// aligned, preserving the values at the corner pixels. Defaults to false.
|
||||||
|
// If not specified, defaults to false
|
||||||
|
func ResizeAreaAlignCorners(value bool) ResizeAreaAttr {
|
||||||
|
return func(m optionalAttr) {
|
||||||
|
m["align_corners"] = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize `images` to `size` using area interpolation.
|
||||||
|
//
|
||||||
|
// Input images can be of different types but output images are always float.
|
||||||
|
//
|
||||||
|
// The range of pixel values for the output image might be slightly different
|
||||||
|
// from the range for the input image because of limited numerical precision.
|
||||||
|
// To guarantee an output range, for example `[0.0, 1.0]`, apply
|
||||||
|
// `tf.clip_by_value` to the output.
|
||||||
|
//
|
||||||
|
// Each output pixel is computed by first transforming the pixel's footprint into
|
||||||
|
// the input tensor and then averaging the pixels that intersect the footprint. An
|
||||||
|
// input pixel's contribution to the average is weighted by the fraction of its
|
||||||
|
// area that intersects the footprint. This is the same as OpenCV's INTER_AREA.
|
||||||
|
//
|
||||||
|
// Arguments:
|
||||||
|
// images: 4-D with shape `[batch, height, width, channels]`.
|
||||||
|
// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
|
||||||
|
// new size for the images.
|
||||||
|
//
|
||||||
|
// Returns 4-D with shape
|
||||||
|
// `[batch, new_height, new_width, channels]`.
|
||||||
|
func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) {
|
||||||
|
if scope.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
attrs := map[string]interface{}{}
|
||||||
|
for _, a := range optional {
|
||||||
|
a(attrs)
|
||||||
|
}
|
||||||
|
opspec := tf.OpSpec{
|
||||||
|
Type: "ResizeArea",
|
||||||
|
Input: []tf.Input{
|
||||||
|
images, size,
|
||||||
|
},
|
||||||
|
Attrs: attrs,
|
||||||
|
}
|
||||||
|
op := scope.AddOperation(opspec)
|
||||||
|
return op.Output(0)
|
||||||
|
}
|
||||||
|
|
||||||
// VarHandleOpAttr is an optional argument to VarHandleOp.
|
// VarHandleOpAttr is an optional argument to VarHandleOp.
|
||||||
type VarHandleOpAttr func(optionalAttr)
|
type VarHandleOpAttr func(optionalAttr)
|
||||||
|
|
||||||
@ -30639,74 +30710,3 @@ func UnravelIndex(scope *Scope, indices tf.Output, dims tf.Output) (output tf.Ou
|
|||||||
op := scope.AddOperation(opspec)
|
op := scope.AddOperation(opspec)
|
||||||
return op.Output(0)
|
return op.Output(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute the lower regularized incomplete Gamma function `Q(a, x)`.
|
|
||||||
//
|
|
||||||
// The lower regularized incomplete Gamma function is defined as:
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// \\(P(a, x) = gamma(a, x) / Gamma(a) = 1 - Q(a, x)\\)
|
|
||||||
//
|
|
||||||
// where
|
|
||||||
//
|
|
||||||
// \\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
|
|
||||||
//
|
|
||||||
// is the lower incomplete Gamma function.
|
|
||||||
//
|
|
||||||
// Note, above `Q(a, x)` (`Igammac`) is the upper regularized complete
|
|
||||||
// Gamma function.
|
|
||||||
func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "Igamma",
|
|
||||||
Input: []tf.Input{
|
|
||||||
a, x,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
return op.Output(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Computes offsets of concat inputs within its output.
|
|
||||||
//
|
|
||||||
// For example:
|
|
||||||
//
|
|
||||||
// ```
|
|
||||||
// # 'x' is [2, 2, 7]
|
|
||||||
// # 'y' is [2, 3, 7]
|
|
||||||
// # 'z' is [2, 5, 7]
|
|
||||||
// concat_offset(2, [x, y, z]) => [0, 0, 0], [0, 2, 0], [0, 5, 0]
|
|
||||||
// ```
|
|
||||||
//
|
|
||||||
// This is typically used by gradient computations for a concat operation.
|
|
||||||
//
|
|
||||||
// Arguments:
|
|
||||||
// concat_dim: The dimension along which to concatenate.
|
|
||||||
// shape: The `N` int32 vectors representing shape of tensors being concatenated.
|
|
||||||
//
|
|
||||||
// Returns The `N` int32 vectors representing the starting offset
|
|
||||||
// of input tensors within the concatenated output.
|
|
||||||
func ConcatOffset(scope *Scope, concat_dim tf.Output, shape []tf.Output) (offset []tf.Output) {
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
opspec := tf.OpSpec{
|
|
||||||
Type: "ConcatOffset",
|
|
||||||
Input: []tf.Input{
|
|
||||||
concat_dim, tf.OutputList(shape),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
op := scope.AddOperation(opspec)
|
|
||||||
if scope.Err() != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var idx int
|
|
||||||
var err error
|
|
||||||
if offset, idx, err = makeOutputList(op, idx, "offset"); err != nil {
|
|
||||||
scope.UpdateErr("ConcatOffset", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return offset
|
|
||||||
}
|
|
||||||
|
@ -2530,6 +2530,7 @@ py_library(
|
|||||||
":check_ops",
|
":check_ops",
|
||||||
":confusion_matrix",
|
":confusion_matrix",
|
||||||
":control_flow_ops",
|
":control_flow_ops",
|
||||||
|
":distribute",
|
||||||
":framework",
|
":framework",
|
||||||
":framework_for_generated_wrappers",
|
":framework_for_generated_wrappers",
|
||||||
":math_ops",
|
":math_ops",
|
||||||
|
@ -259,9 +259,7 @@ class DatasetConstructorTest(test.TestCase):
|
|||||||
sess.run(init_op)
|
sess.run(init_op)
|
||||||
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
self.assertAllEqual([1, 2, 3], sess.run(get_next))
|
||||||
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
self.assertAllEqual([4, 5, 6], sess.run(get_next))
|
||||||
# NOTE(mrry): Type name in message differs between Python 2 (`long`) and
|
with self.assertRaisesOpError("The expected type was int64"):
|
||||||
# 3 (`int`).
|
|
||||||
with self.assertRaisesOpError(r"invalid literal for"):
|
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
self.assertAllEqual([7, 8, 9], sess.run(get_next))
|
self.assertAllEqual([7, 8, 9], sess.run(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
@ -290,6 +288,34 @@ class DatasetConstructorTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
def testFromGeneratorStructureError(self):
|
||||||
|
def generator():
|
||||||
|
yield 1, 2
|
||||||
|
yield 3, 4
|
||||||
|
yield 5
|
||||||
|
yield 6, 7, 8
|
||||||
|
yield 9, 10
|
||||||
|
|
||||||
|
iterator = (dataset_ops.Dataset.from_generator(
|
||||||
|
generator, output_types=(dtypes.int64, dtypes.int64))
|
||||||
|
.make_initializable_iterator())
|
||||||
|
init_op = iterator.initializer
|
||||||
|
get_next = iterator.get_next()
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(init_op)
|
||||||
|
self.assertEqual((1, 2), sess.run(get_next))
|
||||||
|
self.assertEqual((3, 4), sess.run(get_next))
|
||||||
|
with self.assertRaisesOpError(
|
||||||
|
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||||
|
sess.run(get_next)
|
||||||
|
with self.assertRaisesOpError(
|
||||||
|
r"The expected structure was \(tf\.int64, tf\.int64\)"):
|
||||||
|
sess.run(get_next)
|
||||||
|
self.assertEqual((9, 10), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
def testFromGeneratorHeterogeneous(self):
|
def testFromGeneratorHeterogeneous(self):
|
||||||
def generator():
|
def generator():
|
||||||
yield 1
|
yield 1
|
||||||
|
@ -223,6 +223,13 @@ class Dataset(object):
|
|||||||
def from_tensors(tensors):
|
def from_tensors(tensors):
|
||||||
"""Creates a `Dataset` with a single element, comprising the given tensors.
|
"""Creates a `Dataset` with a single element, comprising the given tensors.
|
||||||
|
|
||||||
|
Note that if `tensors` contains a NumPy array, and eager execution is not
|
||||||
|
enabled, the values will be embedded in the graph as one or more
|
||||||
|
@{tf.constant} operations. For large datasets (> 1 GB), this can waste
|
||||||
|
memory and run into byte limits of graph serialization. If tensors contains
|
||||||
|
one or more large NumPy arrays, consider the alternative described in
|
||||||
|
@{$programmers_guide/datasets#consuming_numpy_arrays$this guide}.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensors: A nested structure of tensors.
|
tensors: A nested structure of tensors.
|
||||||
|
|
||||||
@ -235,6 +242,13 @@ class Dataset(object):
|
|||||||
def from_tensor_slices(tensors):
|
def from_tensor_slices(tensors):
|
||||||
"""Creates a `Dataset` whose elements are slices of the given tensors.
|
"""Creates a `Dataset` whose elements are slices of the given tensors.
|
||||||
|
|
||||||
|
Note that if `tensors` contains a NumPy array, and eager execution is not
|
||||||
|
enabled, the values will be embedded in the graph as one or more
|
||||||
|
@{tf.constant} operations. For large datasets (> 1 GB), this can waste
|
||||||
|
memory and run into byte limits of graph serialization. If tensors contains
|
||||||
|
one or more large NumPy arrays, consider the alternative described in
|
||||||
|
@{$programmers_guide/datasets#consuming_numpy_arrays$this guide}.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
tensors: A nested structure of tensors, each having the same size in the
|
tensors: A nested structure of tensors, each having the same size in the
|
||||||
0th dimension.
|
0th dimension.
|
||||||
@ -409,13 +423,23 @@ class Dataset(object):
|
|||||||
# Use the same _convert function from the py_func() implementation to
|
# Use the same _convert function from the py_func() implementation to
|
||||||
# convert the returned values to arrays early, so that we can inspect
|
# convert the returned values to arrays early, so that we can inspect
|
||||||
# their values.
|
# their values.
|
||||||
# pylint: disable=protected-access
|
try:
|
||||||
ret_arrays = [
|
flattened_values = nest.flatten_up_to(output_types, values)
|
||||||
script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype)
|
except (TypeError, ValueError):
|
||||||
for ret, dtype in zip(
|
raise TypeError(
|
||||||
nest.flatten_up_to(output_types, values), flattened_types)
|
"`generator` yielded an element that did not match the expected "
|
||||||
]
|
"structure. The expected structure was %s, but the yielded "
|
||||||
# pylint: enable=protected-access
|
"element was %s." % (output_types, values))
|
||||||
|
ret_arrays = []
|
||||||
|
for ret, dtype in zip(flattened_values, flattened_types):
|
||||||
|
try:
|
||||||
|
ret_arrays.append(script_ops.FuncRegistry._convert( # pylint: disable=protected-access
|
||||||
|
ret, dtype=dtype.as_numpy_dtype))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
raise TypeError(
|
||||||
|
"`generator` yielded an element that could not be converted to "
|
||||||
|
"the expected type. The expected type was %s, but the yielded "
|
||||||
|
"element was %s." % (dtype.name, ret))
|
||||||
|
|
||||||
# Additional type and shape checking to ensure that the components
|
# Additional type and shape checking to ensure that the components
|
||||||
# of the generated element match the `output_types` and `output_shapes`
|
# of the generated element match the `output_types` and `output_shapes`
|
||||||
|
@ -451,17 +451,22 @@ def get_error_intro(tf_error):
|
|||||||
sample commands for debugging.
|
sample commands for debugging.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if hasattr(tf_error, "op") and hasattr(tf_error.op, "name"):
|
||||||
op_name = tf_error.op.name
|
op_name = tf_error.op.name
|
||||||
|
else:
|
||||||
|
op_name = None
|
||||||
|
|
||||||
intro_lines = [
|
intro_lines = [
|
||||||
"--------------------------------------",
|
"--------------------------------------",
|
||||||
RL("!!! An error occurred during the run !!!", "blink"),
|
RL("!!! An error occurred during the run !!!", "blink"),
|
||||||
"",
|
"",
|
||||||
"You may use the following commands to debug:",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines)
|
out = debugger_cli_common.rich_text_lines_from_rich_line_list(intro_lines)
|
||||||
|
|
||||||
|
if op_name is not None:
|
||||||
|
out.extend(debugger_cli_common.RichTextLines(
|
||||||
|
["You may use the following commands to debug:"]))
|
||||||
out.extend(
|
out.extend(
|
||||||
_recommend_command("ni -a -d -t %s" % op_name,
|
_recommend_command("ni -a -d -t %s" % op_name,
|
||||||
"Inspect information about the failing op.",
|
"Inspect information about the failing op.",
|
||||||
@ -476,17 +481,18 @@ def get_error_intro(tf_error):
|
|||||||
"lt",
|
"lt",
|
||||||
"List all tensors dumped during the failing run() call.",
|
"List all tensors dumped during the failing run() call.",
|
||||||
create_link=True))
|
create_link=True))
|
||||||
|
else:
|
||||||
|
out.extend(debugger_cli_common.RichTextLines([
|
||||||
|
"WARNING: Cannot determine the name of the op that caused the error."]))
|
||||||
|
|
||||||
more_lines = [
|
more_lines = [
|
||||||
"",
|
"",
|
||||||
"Op name: " + op_name,
|
"Op name: %s" % op_name,
|
||||||
"Error type: " + str(type(tf_error)),
|
"Error type: " + str(type(tf_error)),
|
||||||
"",
|
"",
|
||||||
"Details:",
|
"Details:",
|
||||||
str(tf_error),
|
str(tf_error),
|
||||||
"",
|
"",
|
||||||
"WARNING: Using client GraphDef due to the error, instead of "
|
|
||||||
"executor GraphDefs.",
|
|
||||||
"--------------------------------------",
|
"--------------------------------------",
|
||||||
"",
|
"",
|
||||||
]
|
]
|
||||||
|
@ -372,6 +372,11 @@ class GetErrorIntroTest(test_util.TensorFlowTestCase):
|
|||||||
self.assertEqual("Details:", error_intro.lines[14])
|
self.assertEqual("Details:", error_intro.lines[14])
|
||||||
self.assertStartsWith(error_intro.lines[15], "foo description")
|
self.assertStartsWith(error_intro.lines[15], "foo description")
|
||||||
|
|
||||||
|
def testGetErrorIntroForNoOpName(self):
|
||||||
|
tf_error = errors.OpError(None, None, "Fake OpError", -1)
|
||||||
|
error_intro = cli_shared.get_error_intro(tf_error)
|
||||||
|
self.assertIn("Cannot determine the name of the op", error_intro.lines[3])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
googletest.main()
|
googletest.main()
|
||||||
|
@ -69,6 +69,12 @@ run
|
|||||||
exit
|
exit
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
|
cat << EOF | ${DEBUG_ERRORS_BIN} --error=uninitialized_variable --debug --ui_type=readline
|
||||||
|
run
|
||||||
|
ni -a -d -t v/read
|
||||||
|
exit
|
||||||
|
EOF
|
||||||
|
|
||||||
cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline
|
cat << EOF | ${DEBUG_MNIST_BIN} --debug --max_steps=1 --fake_data --ui_type=readline
|
||||||
run -t 1
|
run -t 1
|
||||||
run --node_name_filter hidden --op_type_filter MatMul
|
run --node_name_filter hidden --op_type_filter MatMul
|
||||||
|
@ -748,7 +748,7 @@ class DebugDumpDir(object):
|
|||||||
return sum(len(self._dump_tensor_data[device_name])
|
return sum(len(self._dump_tensor_data[device_name])
|
||||||
for device_name in self._dump_tensor_data)
|
for device_name in self._dump_tensor_data)
|
||||||
|
|
||||||
def _load_partition_graphs(self, partition_graphs, validate):
|
def _load_partition_graphs(self, client_partition_graphs, validate):
|
||||||
"""Load and process partition graphs.
|
"""Load and process partition graphs.
|
||||||
|
|
||||||
Load the graphs; parse the input and control input structure; obtain the
|
Load the graphs; parse the input and control input structure; obtain the
|
||||||
@ -757,8 +757,10 @@ class DebugDumpDir(object):
|
|||||||
tensor dumps.
|
tensor dumps.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
partition_graphs: A repeated field of GraphDefs representing the
|
client_partition_graphs: A repeated field of GraphDefs representing the
|
||||||
partition graphs executed by the TensorFlow runtime.
|
partition graphs executed by the TensorFlow runtime, from the Python
|
||||||
|
client. These partition graphs are used only if partition graphs
|
||||||
|
cannot be loaded from the dump directory on the file system.
|
||||||
validate: (`bool`) Whether the dump files are to be validated against the
|
validate: (`bool`) Whether the dump files are to be validated against the
|
||||||
partition graphs.
|
partition graphs.
|
||||||
|
|
||||||
@ -769,10 +771,6 @@ class DebugDumpDir(object):
|
|||||||
self._debug_graphs = {}
|
self._debug_graphs = {}
|
||||||
self._node_devices = {}
|
self._node_devices = {}
|
||||||
|
|
||||||
if partition_graphs:
|
|
||||||
partition_graphs_and_device_names = [
|
|
||||||
(partition_graph, None) for partition_graph in partition_graphs]
|
|
||||||
else:
|
|
||||||
partition_graphs_and_device_names = []
|
partition_graphs_and_device_names = []
|
||||||
for device_name in self._device_names:
|
for device_name in self._device_names:
|
||||||
partition_graph = None
|
partition_graph = None
|
||||||
@ -780,13 +778,16 @@ class DebugDumpDir(object):
|
|||||||
partition_graph = _load_graph_def_from_event_file(
|
partition_graph = _load_graph_def_from_event_file(
|
||||||
self._dump_graph_file_paths[device_name])
|
self._dump_graph_file_paths[device_name])
|
||||||
else:
|
else:
|
||||||
partition_graph = self._find_partition_graph(partition_graphs,
|
logging.warn(
|
||||||
|
"Failed to load partition graphs for device %s from disk. "
|
||||||
|
"As a fallback, the client graphs will be used. This "
|
||||||
|
"may cause mismatches in device names." % device_name)
|
||||||
|
partition_graph = self._find_partition_graph(client_partition_graphs,
|
||||||
device_name)
|
device_name)
|
||||||
|
|
||||||
if partition_graph:
|
if partition_graph:
|
||||||
partition_graphs_and_device_names.append((partition_graph,
|
partition_graphs_and_device_names.append((partition_graph,
|
||||||
device_name))
|
device_name))
|
||||||
else:
|
|
||||||
logging.warn("Failed to load partition graphs from disk.")
|
|
||||||
|
|
||||||
for partition_graph, maybe_device_name in partition_graphs_and_device_names:
|
for partition_graph, maybe_device_name in partition_graphs_and_device_names:
|
||||||
debug_graph = debug_graphs.DebugGraph(partition_graph,
|
debug_graph = debug_graphs.DebugGraph(partition_graph,
|
||||||
|
@ -1873,6 +1873,8 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
|
|||||||
delete backward_function;
|
delete backward_function;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
Py_DECREF(num_inputs);
|
||||||
|
|
||||||
Py_RETURN_NONE;
|
Py_RETURN_NONE;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1931,8 +1933,10 @@ bool ReadVariableOp(const FastPathOpExecInfo& parent_op_exec_info,
|
|||||||
Py_INCREF(output->get()); // stay alive after since tuple steals.
|
Py_INCREF(output->get()); // stay alive after since tuple steals.
|
||||||
PyTuple_SET_ITEM(outputs.get(), 0, output->get());
|
PyTuple_SET_ITEM(outputs.get(), 0, output->get());
|
||||||
|
|
||||||
if (!RecordGradient(GetPythonObjectFromString("ReadVariableOp"),
|
tensorflow::Safe_PyObjectPtr op_string(
|
||||||
inputs.get(), Py_None, outputs.get(), Py_None)) {
|
GetPythonObjectFromString("ReadVariableOp"));
|
||||||
|
if (!RecordGradient(op_string.get(), inputs.get(), Py_None, outputs.get(),
|
||||||
|
Py_None)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1242,11 +1242,11 @@ class TensorFlowTestCase(googletest.TestCase):
|
|||||||
b,
|
b,
|
||||||
rtol=rtol,
|
rtol=rtol,
|
||||||
atol=atol,
|
atol=atol,
|
||||||
msg="Mismatched value: a%s is different from b%s." % (path_str,
|
msg=("Mismatched value: a%s is different from b%s. %s" %
|
||||||
path_str))
|
(path_str, path_str, msg)))
|
||||||
except TypeError as e:
|
except TypeError as e:
|
||||||
msg = "Error: a%s has %s, but b%s has %s" % (path_str, type(a),
|
msg = ("Error: a%s has %s, but b%s has %s. %s" %
|
||||||
path_str, type(b))
|
(path_str, type(a), path_str, type(b), msg))
|
||||||
e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
|
e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@ -118,6 +118,7 @@ class LocallyConnectedLayersTest(test.TestCase):
|
|||||||
},
|
},
|
||||||
input_shape=(num_samples, num_row, num_col, stack_size))
|
input_shape=(num_samples, num_row, num_col, stack_size))
|
||||||
|
|
||||||
|
@tf_test_util.run_in_graph_and_eager_modes()
|
||||||
def test_locallyconnected_2d_channels_first(self):
|
def test_locallyconnected_2d_channels_first(self):
|
||||||
num_samples = 8
|
num_samples = 8
|
||||||
filters = 3
|
filters = 3
|
||||||
@ -125,7 +126,6 @@ class LocallyConnectedLayersTest(test.TestCase):
|
|||||||
num_row = 6
|
num_row = 6
|
||||||
num_col = 10
|
num_col = 10
|
||||||
|
|
||||||
with self.test_session():
|
|
||||||
testing_utils.layer_test(
|
testing_utils.layer_test(
|
||||||
keras.layers.LocallyConnected2D,
|
keras.layers.LocallyConnected2D,
|
||||||
kwargs={
|
kwargs={
|
||||||
|
@ -241,6 +241,12 @@ class UnaryOpTest(test.TestCase):
|
|||||||
math_ops.lgamma)
|
math_ops.lgamma)
|
||||||
self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
|
self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
|
||||||
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
|
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
|
||||||
|
try:
|
||||||
|
from scipy import special # pylint: disable=g-import-not-at-top
|
||||||
|
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
|
||||||
|
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
|
||||||
|
except ImportError as e:
|
||||||
|
tf_logging.warn("Cannot test special functions: %s" % str(e))
|
||||||
|
|
||||||
self._compareBothSparse(x, np.abs, math_ops.abs)
|
self._compareBothSparse(x, np.abs, math_ops.abs)
|
||||||
self._compareBothSparse(x, np.negative, math_ops.negative)
|
self._compareBothSparse(x, np.negative, math_ops.negative)
|
||||||
@ -286,6 +292,12 @@ class UnaryOpTest(test.TestCase):
|
|||||||
self._compareBoth(x, np.arcsin, math_ops.asin)
|
self._compareBoth(x, np.arcsin, math_ops.asin)
|
||||||
self._compareBoth(x, np.arccos, math_ops.acos)
|
self._compareBoth(x, np.arccos, math_ops.acos)
|
||||||
self._compareBoth(x, np.arctan, math_ops.atan)
|
self._compareBoth(x, np.arctan, math_ops.atan)
|
||||||
|
try:
|
||||||
|
from scipy import special # pylint: disable=g-import-not-at-top
|
||||||
|
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
|
||||||
|
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
|
||||||
|
except ImportError as e:
|
||||||
|
tf_logging.warn("Cannot test special functions: %s" % str(e))
|
||||||
|
|
||||||
self._compareBothSparse(x, np.abs, math_ops.abs)
|
self._compareBothSparse(x, np.abs, math_ops.abs)
|
||||||
self._compareBothSparse(x, np.negative, math_ops.negative)
|
self._compareBothSparse(x, np.negative, math_ops.negative)
|
||||||
@ -334,6 +346,12 @@ class UnaryOpTest(test.TestCase):
|
|||||||
self._compareBoth(k, np.arcsin, math_ops.asin)
|
self._compareBoth(k, np.arcsin, math_ops.asin)
|
||||||
self._compareBoth(k, np.arccos, math_ops.acos)
|
self._compareBoth(k, np.arccos, math_ops.acos)
|
||||||
self._compareBoth(k, np.tan, math_ops.tan)
|
self._compareBoth(k, np.tan, math_ops.tan)
|
||||||
|
try:
|
||||||
|
from scipy import special # pylint: disable=g-import-not-at-top
|
||||||
|
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
|
||||||
|
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
|
||||||
|
except ImportError as e:
|
||||||
|
tf_logging.warn("Cannot test special functions: %s" % str(e))
|
||||||
|
|
||||||
self._compareBothSparse(x, np.abs, math_ops.abs)
|
self._compareBothSparse(x, np.abs, math_ops.abs)
|
||||||
self._compareBothSparse(x, np.negative, math_ops.negative)
|
self._compareBothSparse(x, np.negative, math_ops.negative)
|
||||||
@ -370,6 +388,12 @@ class UnaryOpTest(test.TestCase):
|
|||||||
math_ops.lgamma)
|
math_ops.lgamma)
|
||||||
self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
|
self._compareBoth(x, np.vectorize(math.erf), math_ops.erf)
|
||||||
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
|
self._compareBoth(x, np.vectorize(math.erfc), math_ops.erfc)
|
||||||
|
try:
|
||||||
|
from scipy import special # pylint: disable=g-import-not-at-top
|
||||||
|
self._compareBoth(x, special.i0e, math_ops.bessel_i0e)
|
||||||
|
self._compareBoth(x, special.i1e, math_ops.bessel_i1e)
|
||||||
|
except ImportError as e:
|
||||||
|
tf_logging.warn("Cannot test special functions: %s" % str(e))
|
||||||
|
|
||||||
self._compareBothSparse(x, np.abs, math_ops.abs)
|
self._compareBothSparse(x, np.abs, math_ops.abs)
|
||||||
self._compareBothSparse(x, np.negative, math_ops.negative)
|
self._compareBothSparse(x, np.negative, math_ops.negative)
|
||||||
|
@ -939,7 +939,8 @@ class ResizeMethod(object):
|
|||||||
def resize_images(images,
|
def resize_images(images,
|
||||||
size,
|
size,
|
||||||
method=ResizeMethod.BILINEAR,
|
method=ResizeMethod.BILINEAR,
|
||||||
align_corners=False):
|
align_corners=False,
|
||||||
|
preserve_aspect_ratio=False):
|
||||||
"""Resize `images` to `size` using the specified `method`.
|
"""Resize `images` to `size` using the specified `method`.
|
||||||
|
|
||||||
Resized images will be distorted if their original aspect ratio is not
|
Resized images will be distorted if their original aspect ratio is not
|
||||||
@ -971,6 +972,10 @@ def resize_images(images,
|
|||||||
align_corners: bool. If True, the centers of the 4 corner pixels of the
|
align_corners: bool. If True, the centers of the 4 corner pixels of the
|
||||||
input and output tensors are aligned, preserving the values at the
|
input and output tensors are aligned, preserving the values at the
|
||||||
corner pixels. Defaults to `False`.
|
corner pixels. Defaults to `False`.
|
||||||
|
preserve_aspect_ratio: Whether to preserve the aspect ratio. If this is set,
|
||||||
|
then `images` will be resized to a size that fits in `size` while
|
||||||
|
preserving the aspect ratio of the original image. Scales up the image if
|
||||||
|
`size` is bigger than the current size of the `image`. Defaults to False.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the shape of `images` is incompatible with the
|
ValueError: if the shape of `images` is incompatible with the
|
||||||
@ -1009,6 +1014,28 @@ def resize_images(images,
|
|||||||
new_height_const = size_const_as_shape[0].value
|
new_height_const = size_const_as_shape[0].value
|
||||||
new_width_const = size_const_as_shape[1].value
|
new_width_const = size_const_as_shape[1].value
|
||||||
|
|
||||||
|
if preserve_aspect_ratio:
|
||||||
|
# Get the current shapes of the image, even if dynamic.
|
||||||
|
_, current_height, current_width, _ = _ImageDimensions(images, rank=4)
|
||||||
|
|
||||||
|
# do the computation to find the right scale and height/width.
|
||||||
|
scale_factor_height = (math_ops.to_float(new_height_const) /
|
||||||
|
math_ops.to_float(current_height))
|
||||||
|
scale_factor_width = (math_ops.to_float(new_width_const) /
|
||||||
|
math_ops.to_float(current_width))
|
||||||
|
scale_factor = math_ops.minimum(scale_factor_height, scale_factor_width)
|
||||||
|
scaled_height_const = math_ops.to_int32(scale_factor *
|
||||||
|
math_ops.to_float(current_height))
|
||||||
|
scaled_width_const = math_ops.to_int32(scale_factor *
|
||||||
|
math_ops.to_float(current_width))
|
||||||
|
|
||||||
|
# NOTE: Reset the size and other constants used later.
|
||||||
|
size = ops.convert_to_tensor([scaled_height_const, scaled_width_const],
|
||||||
|
dtypes.int32, name='size')
|
||||||
|
size_const_as_shape = tensor_util.constant_value_as_shape(size)
|
||||||
|
new_height_const = size_const_as_shape[0].value
|
||||||
|
new_width_const = size_const_as_shape[1].value
|
||||||
|
|
||||||
# If we can determine that the height and width will be unmodified by this
|
# If we can determine that the height and width will be unmodified by this
|
||||||
# transformation, we avoid performing the resize.
|
# transformation, we avoid performing the resize.
|
||||||
if all(x is not None
|
if all(x is not None
|
||||||
@ -1469,6 +1496,75 @@ def adjust_hue(image, delta, name=None):
|
|||||||
return convert_image_dtype(rgb_altered, orig_dtype)
|
return convert_image_dtype(rgb_altered, orig_dtype)
|
||||||
|
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
@tf_export('image.random_jpeg_quality')
|
||||||
|
def random_jpeg_quality(image, min_jpeg_quality, max_jpeg_quality, seed=None):
|
||||||
|
"""Randomly changes jpeg encoding quality for inducing jpeg noise.
|
||||||
|
|
||||||
|
`min_jpeg_quality` must be in the interval `[0, 100]` and less than
|
||||||
|
`max_jpeg_quality`.
|
||||||
|
`max_jpeg_quality` must be in the interval `[0, 100]`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: RGB image or images. Size of the last dimension must be 3.
|
||||||
|
min_jpeg_quality: Minimum jpeg encoding quality to use.
|
||||||
|
max_jpeg_quality: Maximum jpeg encoding quality to use.
|
||||||
|
seed: An operation-specific seed. It will be used in conjunction
|
||||||
|
with the graph-level seed to determine the real seeds that will be
|
||||||
|
used in this operation. Please see the documentation of
|
||||||
|
set_random_seed for its interaction with the graph-level random seed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adjusted image(s), same shape and DType as `image`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `min_jpeg_quality` or `max_jpeg_quality` is invalid.
|
||||||
|
"""
|
||||||
|
if (min_jpeg_quality < 0 or max_jpeg_quality < 0 or
|
||||||
|
min_jpeg_quality > 100 or max_jpeg_quality > 100):
|
||||||
|
raise ValueError('jpeg encoding range must be between 0 and 100.')
|
||||||
|
|
||||||
|
if min_jpeg_quality >= max_jpeg_quality:
|
||||||
|
raise ValueError('`min_jpeg_quality` must be less than `max_jpeg_quality`.')
|
||||||
|
|
||||||
|
np.random.seed(seed)
|
||||||
|
jpeg_quality = np.random.randint(min_jpeg_quality, max_jpeg_quality)
|
||||||
|
return adjust_jpeg_quality(image, jpeg_quality)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('image.adjust_jpeg_quality')
|
||||||
|
def adjust_jpeg_quality(image, jpeg_quality, name=None):
|
||||||
|
"""Adjust jpeg encoding quality of an RGB image.
|
||||||
|
|
||||||
|
This is a convenience method that adjusts jpeg encoding quality of an
|
||||||
|
RGB image.
|
||||||
|
|
||||||
|
`image` is an RGB image. The image's encoding quality is adjusted
|
||||||
|
to `jpeg_quality`.
|
||||||
|
`jpeg_quality` must be in the interval `[0, 100]`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: RGB image or images. Size of the last dimension must be 3.
|
||||||
|
jpeg_quality: int. jpeg encoding quality.
|
||||||
|
name: A name for this operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Adjusted image(s), same shape and DType as `image`.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, 'adjust_jpeg_quality', [image]) as name:
|
||||||
|
image = ops.convert_to_tensor(image, name='image')
|
||||||
|
# Remember original dtype to so we can convert back if needed
|
||||||
|
orig_dtype = image.dtype
|
||||||
|
# Convert to uint8
|
||||||
|
image = convert_image_dtype(image, dtypes.uint8)
|
||||||
|
# Encode image to jpeg with given jpeg quality
|
||||||
|
image = gen_image_ops.encode_jpeg(image, quality=jpeg_quality)
|
||||||
|
# Decode jpeg image
|
||||||
|
image = gen_image_ops.decode_jpeg(image)
|
||||||
|
# Convert back to original dtype and return
|
||||||
|
return convert_image_dtype(image, orig_dtype)
|
||||||
|
|
||||||
|
|
||||||
@tf_export('image.random_saturation')
|
@tf_export('image.random_saturation')
|
||||||
def random_saturation(image, lower, upper, seed=None):
|
def random_saturation(image, lower, upper, seed=None):
|
||||||
"""Adjust the saturation of an RGB image by a random factor.
|
"""Adjust the saturation of an RGB image by a random factor.
|
||||||
|
@ -2599,6 +2599,86 @@ class ResizeImagesTest(test_util.TensorFlowTestCase):
|
|||||||
y = image_ops.resize_images(single_image, [55, 66])
|
y = image_ops.resize_images(single_image, [55, 66])
|
||||||
self.assertTrue(y.op.name.startswith("resize_images"))
|
self.assertTrue(y.op.name.startswith("resize_images"))
|
||||||
|
|
||||||
|
def _ResizeImageCall(self, x, max_h, max_w, preserve_aspect_ratio,
|
||||||
|
use_tensor_inputs):
|
||||||
|
if use_tensor_inputs:
|
||||||
|
target_max = ops.convert_to_tensor([max_h, max_w])
|
||||||
|
x_tensor = array_ops.placeholder(x.dtype, shape=[None] * x.ndim)
|
||||||
|
feed_dict = {x_tensor: x}
|
||||||
|
else:
|
||||||
|
target_max = [max_h, max_w]
|
||||||
|
x_tensor = x
|
||||||
|
feed_dict = {}
|
||||||
|
|
||||||
|
y = image_ops.resize_images(x_tensor, target_max,
|
||||||
|
preserve_aspect_ratio=preserve_aspect_ratio)
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
return y.eval(feed_dict=feed_dict)
|
||||||
|
|
||||||
|
def _assertResizeEqual(self, x, x_shape, y, y_shape,
|
||||||
|
preserve_aspect_ratio=True,
|
||||||
|
use_tensor_inputs_options=None):
|
||||||
|
use_tensor_inputs_options = use_tensor_inputs_options or [False, True]
|
||||||
|
target_height, target_width, _ = y_shape
|
||||||
|
x = np.array(x).reshape(x_shape)
|
||||||
|
y = np.array(y).reshape(y_shape)
|
||||||
|
|
||||||
|
for use_tensor_inputs in use_tensor_inputs_options:
|
||||||
|
y_tf = self._ResizeImageCall(x, target_height, target_width,
|
||||||
|
preserve_aspect_ratio, use_tensor_inputs)
|
||||||
|
self.assertAllClose(y, y_tf)
|
||||||
|
|
||||||
|
def _assertResizeCheckShape(self, x, x_shape, target_shape,
|
||||||
|
y_shape, preserve_aspect_ratio=True,
|
||||||
|
use_tensor_inputs_options=None):
|
||||||
|
use_tensor_inputs_options = use_tensor_inputs_options or [False, True]
|
||||||
|
target_height, target_width = target_shape
|
||||||
|
x = np.array(x).reshape(x_shape)
|
||||||
|
y = np.zeros(y_shape)
|
||||||
|
|
||||||
|
for use_tensor_inputs in use_tensor_inputs_options:
|
||||||
|
y_tf = self._ResizeImageCall(x, target_height, target_width,
|
||||||
|
preserve_aspect_ratio, use_tensor_inputs)
|
||||||
|
self.assertShapeEqual(y, ops.convert_to_tensor(y_tf))
|
||||||
|
|
||||||
|
def testPreserveAspectRatioMultipleImages(self):
|
||||||
|
x_shape = [10, 100, 100, 10]
|
||||||
|
x = np.random.uniform(size=x_shape)
|
||||||
|
|
||||||
|
self._assertResizeCheckShape(x, x_shape, [250, 250], [10, 250, 250, 10],
|
||||||
|
preserve_aspect_ratio=False)
|
||||||
|
|
||||||
|
def testPreserveAspectRatioNoOp(self):
|
||||||
|
x_shape = [10, 10, 10]
|
||||||
|
x = np.random.uniform(size=x_shape)
|
||||||
|
|
||||||
|
self._assertResizeEqual(x, x_shape, x, x_shape)
|
||||||
|
|
||||||
|
def testPreserveAspectRatioSmaller(self):
|
||||||
|
x_shape = [100, 100, 10]
|
||||||
|
x = np.random.uniform(size=x_shape)
|
||||||
|
|
||||||
|
self._assertResizeCheckShape(x, x_shape, [75, 50], [50, 50, 10])
|
||||||
|
|
||||||
|
def testPreserveAspectRatioSmallerMultipleImages(self):
|
||||||
|
x_shape = [10, 100, 100, 10]
|
||||||
|
x = np.random.uniform(size=x_shape)
|
||||||
|
|
||||||
|
self._assertResizeCheckShape(x, x_shape, [75, 50], [10, 50, 50, 10])
|
||||||
|
|
||||||
|
def testPreserveAspectRatioLarger(self):
|
||||||
|
x_shape = [100, 100, 10]
|
||||||
|
x = np.random.uniform(size=x_shape)
|
||||||
|
|
||||||
|
self._assertResizeCheckShape(x, x_shape, [150, 200], [150, 150, 10])
|
||||||
|
|
||||||
|
def testPreserveAspectRatioSameRatio(self):
|
||||||
|
x_shape = [1920, 1080, 3]
|
||||||
|
x = np.random.uniform(size=x_shape)
|
||||||
|
|
||||||
|
self._assertResizeCheckShape(x, x_shape, [3840, 2160], [3840, 2160, 3])
|
||||||
|
|
||||||
|
|
||||||
class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase):
|
class ResizeImageWithCropOrPadTest(test_util.TensorFlowTestCase):
|
||||||
|
|
||||||
|
@ -620,6 +620,35 @@ def _DigammaGrad(op, grad):
|
|||||||
return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
|
return grad * math_ops.polygamma(array_ops.constant(1, dtype=x.dtype), x)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("BesselI0e")
|
||||||
|
def _BesselI0eGrad(op, grad):
|
||||||
|
"""Compute gradient of bessel_i0e(x) with respect to its argument."""
|
||||||
|
x = op.inputs[0]
|
||||||
|
y = op.outputs[0]
|
||||||
|
with ops.control_dependencies([grad]):
|
||||||
|
return grad * (math_ops.bessel_i1e(x) - math_ops.sign(x) * y)
|
||||||
|
|
||||||
|
|
||||||
|
@ops.RegisterGradient("BesselI1e")
|
||||||
|
def _BesselI1eGrad(op, grad):
|
||||||
|
"""Compute gradient of bessel_i1e(x) with respect to its argument."""
|
||||||
|
x = op.inputs[0]
|
||||||
|
y = op.outputs[0]
|
||||||
|
with ops.control_dependencies([grad]):
|
||||||
|
# For x = 0, the correct gradient is 0.5.
|
||||||
|
# However, the main branch gives NaN because of the division by x, so
|
||||||
|
# we impute the gradient manually.
|
||||||
|
# An alternative solution is to express the gradient via bessel_i0e and
|
||||||
|
# bessel_i2e, but the latter is not yet implemented in Eigen.
|
||||||
|
eps = np.finfo(x.dtype.as_numpy_dtype).eps
|
||||||
|
zeros = array_ops.zeros_like(x)
|
||||||
|
x_is_not_tiny = math_ops.abs(x) > eps
|
||||||
|
safe_x = array_ops.where(x_is_not_tiny, x, eps + zeros)
|
||||||
|
dy_dx = math_ops.bessel_i0e(safe_x) - y * (
|
||||||
|
math_ops.sign(safe_x) + math_ops.reciprocal(safe_x))
|
||||||
|
return grad * array_ops.where(x_is_not_tiny, dy_dx, 0.5 + zeros)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterGradient("Igamma")
|
@ops.RegisterGradient("Igamma")
|
||||||
def _IgammaGrad(op, grad):
|
def _IgammaGrad(op, grad):
|
||||||
"""Returns gradient of igamma(a, x) with respect to x."""
|
"""Returns gradient of igamma(a, x) with respect to x."""
|
||||||
|
@ -2954,6 +2954,67 @@ def polyval(coeffs, x, name=None):
|
|||||||
p = c + p * x
|
p = c + p * x
|
||||||
return p
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("math.bessel_i0e")
|
||||||
|
def bessel_i0e(x, name=None):
|
||||||
|
"""Computes the Bessel i0e function of `x` element-wise.
|
||||||
|
|
||||||
|
Exponentially scaled modified Bessel function of order 0 defined as
|
||||||
|
`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`.
|
||||||
|
|
||||||
|
This function is faster and numerically stabler than `bessel_i0(x)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
|
||||||
|
`float32`, `float64`.
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
|
||||||
|
|
||||||
|
@compatibility(scipy)
|
||||||
|
Equivalent to scipy.special.i0e
|
||||||
|
@end_compatibility
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, "bessel_i0e", [x]) as name:
|
||||||
|
if isinstance(x, sparse_tensor.SparseTensor):
|
||||||
|
x_i0e = gen_math_ops.bessel_i0e(x.values, name=name)
|
||||||
|
return sparse_tensor.SparseTensor(
|
||||||
|
indices=x.indices, values=x_i0e, dense_shape=x.dense_shape)
|
||||||
|
else:
|
||||||
|
return gen_math_ops.bessel_i0e(x, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("math.bessel_i1e")
|
||||||
|
def bessel_i1e(x, name=None):
|
||||||
|
"""Computes the Bessel i1e function of `x` element-wise.
|
||||||
|
|
||||||
|
Exponentially scaled modified Bessel function of order 1 defined as
|
||||||
|
`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`.
|
||||||
|
|
||||||
|
This function is faster and numerically stabler than `bessel_i1(x)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
|
||||||
|
`float32`, `float64`.
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
|
||||||
|
|
||||||
|
@compatibility(scipy)
|
||||||
|
Equivalent to scipy.special.i1e
|
||||||
|
@end_compatibility
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, "bessel_i1e", [x]) as name:
|
||||||
|
if isinstance(x, sparse_tensor.SparseTensor):
|
||||||
|
x_i1e = gen_math_ops.bessel_i1e(x.values, name=name)
|
||||||
|
return sparse_tensor.SparseTensor(
|
||||||
|
indices=x.indices, values=x_i1e, dense_shape=x.dense_shape)
|
||||||
|
else:
|
||||||
|
return gen_math_ops.bessel_i1e(x, name=name)
|
||||||
|
|
||||||
|
|
||||||
# FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow
|
# FFT ops were moved to tf.spectral. tf.fft symbols were part of the TensorFlow
|
||||||
# 1.0 API so we leave these here for backwards compatibility.
|
# 1.0 API so we leave these here for backwards compatibility.
|
||||||
fft = gen_spectral_ops.fft
|
fft = gen_spectral_ops.fft
|
||||||
|
@ -34,16 +34,49 @@ from tensorflow.python.ops import state_ops
|
|||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import weights_broadcast_ops
|
from tensorflow.python.ops import weights_broadcast_ops
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.training import distribute as distribute_lib
|
||||||
from tensorflow.python.util.deprecation import deprecated
|
from tensorflow.python.util.deprecation import deprecated
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
def metric_variable(shape, dtype, validate_shape=True, name=None):
|
def metric_variable(shape, dtype, validate_shape=True, name=None):
|
||||||
"""Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections."""
|
"""Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections.
|
||||||
|
|
||||||
|
If running in a `DistributionStrategy` context, the variable will be
|
||||||
|
"tower local". This means:
|
||||||
|
|
||||||
|
* The returned object will be a container with separate variables
|
||||||
|
per replica/tower of the model.
|
||||||
|
|
||||||
|
* When writing to the variable, e.g. using `assign_add` in a metric
|
||||||
|
update, the update will be applied to the variable local to the
|
||||||
|
replica/tower.
|
||||||
|
|
||||||
|
* To get a metric's result value, we need to sum the variable values
|
||||||
|
across the replicas/towers before computing the final answer.
|
||||||
|
Furthermore, the final answer should be computed once instead of
|
||||||
|
in every replica/tower. Both of these are accomplished by
|
||||||
|
running the computation of the final result value inside
|
||||||
|
`tf.contrib.distribute.get_tower_context().merge_call(fn)`.
|
||||||
|
Inside the `merge_call()`, ops are only added to the graph once
|
||||||
|
and access to a tower-local variable in a computation returns
|
||||||
|
the sum across all replicas/towers.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
shape: Shape of the created variable.
|
||||||
|
dtype: Type of the created variable.
|
||||||
|
validate_shape: (Optional) Whether shape validation is enabled for
|
||||||
|
the created variable.
|
||||||
|
name: (Optional) String name of the created variable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A (non-trainable) variable initialized to zero, or if inside a
|
||||||
|
`DistributionStrategy` scope a tower-local variable container.
|
||||||
|
"""
|
||||||
|
with distribute_lib.get_tower_context().tower_local_var_scope('sum'):
|
||||||
|
# Note that "tower local" implies trainable=False.
|
||||||
return variable_scope.variable(
|
return variable_scope.variable(
|
||||||
lambda: array_ops.zeros(shape, dtype),
|
lambda: array_ops.zeros(shape, dtype),
|
||||||
trainable=False,
|
|
||||||
collections=[
|
collections=[
|
||||||
ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
|
ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
|
||||||
],
|
],
|
||||||
@ -333,11 +366,15 @@ def mean(values,
|
|||||||
with ops.control_dependencies([values]):
|
with ops.control_dependencies([values]):
|
||||||
update_count_op = state_ops.assign_add(count, num_values)
|
update_count_op = state_ops.assign_add(count, num_values)
|
||||||
|
|
||||||
mean_t = _safe_div(total, count, 'value')
|
def aggregate_across_towers(_, t, c):
|
||||||
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
|
mean_t = _safe_div(t, c, 'value')
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, mean_t)
|
ops.add_to_collections(metrics_collections, mean_t)
|
||||||
|
return mean_t
|
||||||
|
|
||||||
|
mean_t = distribute_lib.get_tower_context().merge_call(
|
||||||
|
aggregate_across_towers, total, count)
|
||||||
|
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
@ -572,6 +609,17 @@ def _confusion_matrix_at_thresholds(labels,
|
|||||||
return values, update_ops
|
return values, update_ops
|
||||||
|
|
||||||
|
|
||||||
|
def _aggregate_variable(v, collections):
|
||||||
|
|
||||||
|
def f(distribution, value):
|
||||||
|
value = distribution.read_var(value)
|
||||||
|
if collections:
|
||||||
|
ops.add_to_collections(collections, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
return distribute_lib.get_tower_context().merge_call(f, v)
|
||||||
|
|
||||||
|
|
||||||
@tf_export('metrics.auc')
|
@tf_export('metrics.auc')
|
||||||
def auc(labels,
|
def auc(labels,
|
||||||
predictions,
|
predictions,
|
||||||
@ -757,13 +805,17 @@ def auc(labels,
|
|||||||
raise ValueError('Invalid summation_method: %s' % summation_method)
|
raise ValueError('Invalid summation_method: %s' % summation_method)
|
||||||
|
|
||||||
# sum up the areas of all the trapeziums
|
# sum up the areas of all the trapeziums
|
||||||
|
def aggregate_auc(_, values):
|
||||||
auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
|
auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
|
||||||
values['fp'], 'value')
|
values['fp'], 'value')
|
||||||
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
|
|
||||||
update_ops['tn'], update_ops['fp'], 'update_op')
|
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, auc_value)
|
ops.add_to_collections(metrics_collections, auc_value)
|
||||||
|
return auc_value
|
||||||
|
|
||||||
|
auc_value = distribute_lib.get_tower_context().merge_call(
|
||||||
|
aggregate_auc, values)
|
||||||
|
update_op = compute_auc(update_ops['tp'], update_ops['fn'],
|
||||||
|
update_ops['tn'], update_ops['fp'], 'update_op')
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
@ -992,15 +1044,18 @@ def mean_per_class_accuracy(labels,
|
|||||||
update_total_op = state_ops.scatter_add(total, labels, ones)
|
update_total_op = state_ops.scatter_add(total, labels, ones)
|
||||||
update_count_op = state_ops.scatter_add(count, labels, is_correct)
|
update_count_op = state_ops.scatter_add(count, labels, is_correct)
|
||||||
|
|
||||||
|
def aggregate_mean_accuracy(_, count, total):
|
||||||
per_class_accuracy = _safe_div(count, total, None)
|
per_class_accuracy = _safe_div(count, total, None)
|
||||||
|
|
||||||
mean_accuracy_v = math_ops.reduce_mean(
|
mean_accuracy_v = math_ops.reduce_mean(
|
||||||
per_class_accuracy, name='mean_accuracy')
|
per_class_accuracy, name='mean_accuracy')
|
||||||
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
|
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, mean_accuracy_v)
|
ops.add_to_collections(metrics_collections, mean_accuracy_v)
|
||||||
|
return mean_accuracy_v
|
||||||
|
|
||||||
|
mean_accuracy_v = distribute_lib.get_tower_context().merge_call(
|
||||||
|
aggregate_mean_accuracy, count, total)
|
||||||
|
|
||||||
|
update_op = _safe_div(update_count_op, update_total_op, name='update_op')
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
|
|
||||||
@ -1071,7 +1126,7 @@ def mean_iou(labels,
|
|||||||
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
|
total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
|
||||||
num_classes, weights)
|
num_classes, weights)
|
||||||
|
|
||||||
def compute_mean_iou(name):
|
def compute_mean_iou(total_cm, name):
|
||||||
"""Compute the mean intersection-over-union via the confusion matrix."""
|
"""Compute the mean intersection-over-union via the confusion matrix."""
|
||||||
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
|
sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
|
||||||
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
|
sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
|
||||||
@ -1098,10 +1153,14 @@ def mean_iou(labels,
|
|||||||
math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
|
math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
mean_iou_v = compute_mean_iou('mean_iou')
|
def mean_iou_across_towers(_, v):
|
||||||
|
mean_iou_v = compute_mean_iou(v, 'mean_iou')
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, mean_iou_v)
|
ops.add_to_collections(metrics_collections, mean_iou_v)
|
||||||
|
return mean_iou_v
|
||||||
|
|
||||||
|
mean_iou_v = distribute_lib.get_tower_context().merge_call(
|
||||||
|
mean_iou_across_towers, total_cm)
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
@ -1310,12 +1369,16 @@ def mean_tensor(values,
|
|||||||
with ops.control_dependencies([values]):
|
with ops.control_dependencies([values]):
|
||||||
update_count_op = state_ops.assign_add(count, num_values)
|
update_count_op = state_ops.assign_add(count, num_values)
|
||||||
|
|
||||||
mean_t = _safe_div(total, count, 'value')
|
def aggregate_across_towers(_, t, c):
|
||||||
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
|
mean_t = _safe_div(t, c, 'value')
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, mean_t)
|
ops.add_to_collections(metrics_collections, mean_t)
|
||||||
|
return mean_t
|
||||||
|
|
||||||
|
mean_t = distribute_lib.get_tower_context().merge_call(
|
||||||
|
aggregate_across_towers, total, count)
|
||||||
|
|
||||||
|
update_op = _safe_div(update_total_op, update_count_op, 'update_op')
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
|
|
||||||
@ -1413,12 +1476,9 @@ def _count_condition(values,
|
|||||||
weights = math_ops.to_float(weights)
|
weights = math_ops.to_float(weights)
|
||||||
values = math_ops.multiply(values, weights)
|
values = math_ops.multiply(values, weights)
|
||||||
|
|
||||||
value_tensor = array_ops.identity(count)
|
value_tensor = _aggregate_variable(count, metrics_collections)
|
||||||
|
|
||||||
update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
|
update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
|
||||||
|
|
||||||
if metrics_collections:
|
|
||||||
ops.add_to_collections(metrics_collections, value_tensor)
|
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
|
|
||||||
@ -1525,13 +1585,12 @@ def false_negatives_at_thresholds(labels,
|
|||||||
values, update_ops = _confusion_matrix_at_thresholds(
|
values, update_ops = _confusion_matrix_at_thresholds(
|
||||||
labels, predictions, thresholds, weights=weights, includes=('fn',))
|
labels, predictions, thresholds, weights=weights, includes=('fn',))
|
||||||
|
|
||||||
if metrics_collections:
|
fn_value = _aggregate_variable(values['fn'], metrics_collections)
|
||||||
ops.add_to_collections(metrics_collections, values['fn'])
|
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_ops['fn'])
|
ops.add_to_collections(updates_collections, update_ops['fn'])
|
||||||
|
|
||||||
return values['fn'], update_ops['fn']
|
return fn_value, update_ops['fn']
|
||||||
|
|
||||||
|
|
||||||
@tf_export('metrics.false_positives')
|
@tf_export('metrics.false_positives')
|
||||||
@ -1635,13 +1694,12 @@ def false_positives_at_thresholds(labels,
|
|||||||
values, update_ops = _confusion_matrix_at_thresholds(
|
values, update_ops = _confusion_matrix_at_thresholds(
|
||||||
labels, predictions, thresholds, weights=weights, includes=('fp',))
|
labels, predictions, thresholds, weights=weights, includes=('fp',))
|
||||||
|
|
||||||
if metrics_collections:
|
fp_value = _aggregate_variable(values['fp'], metrics_collections)
|
||||||
ops.add_to_collections(metrics_collections, values['fp'])
|
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_ops['fp'])
|
ops.add_to_collections(updates_collections, update_ops['fp'])
|
||||||
|
|
||||||
return values['fp'], update_ops['fp']
|
return fp_value, update_ops['fp']
|
||||||
|
|
||||||
|
|
||||||
@tf_export('metrics.true_negatives')
|
@tf_export('metrics.true_negatives')
|
||||||
@ -1745,13 +1803,12 @@ def true_negatives_at_thresholds(labels,
|
|||||||
values, update_ops = _confusion_matrix_at_thresholds(
|
values, update_ops = _confusion_matrix_at_thresholds(
|
||||||
labels, predictions, thresholds, weights=weights, includes=('tn',))
|
labels, predictions, thresholds, weights=weights, includes=('tn',))
|
||||||
|
|
||||||
if metrics_collections:
|
tn_value = _aggregate_variable(values['tn'], metrics_collections)
|
||||||
ops.add_to_collections(metrics_collections, values['tn'])
|
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_ops['tn'])
|
ops.add_to_collections(updates_collections, update_ops['tn'])
|
||||||
|
|
||||||
return values['tn'], update_ops['tn']
|
return tn_value, update_ops['tn']
|
||||||
|
|
||||||
|
|
||||||
@tf_export('metrics.true_positives')
|
@tf_export('metrics.true_positives')
|
||||||
@ -1855,13 +1912,12 @@ def true_positives_at_thresholds(labels,
|
|||||||
values, update_ops = _confusion_matrix_at_thresholds(
|
values, update_ops = _confusion_matrix_at_thresholds(
|
||||||
labels, predictions, thresholds, weights=weights, includes=('tp',))
|
labels, predictions, thresholds, weights=weights, includes=('tp',))
|
||||||
|
|
||||||
if metrics_collections:
|
tp_value = _aggregate_variable(values['tp'], metrics_collections)
|
||||||
ops.add_to_collections(metrics_collections, values['tp'])
|
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_ops['tp'])
|
ops.add_to_collections(updates_collections, update_ops['tp'])
|
||||||
|
|
||||||
return values['tp'], update_ops['tp']
|
return tp_value, update_ops['tp']
|
||||||
|
|
||||||
|
|
||||||
@tf_export('metrics.precision')
|
@tf_export('metrics.precision')
|
||||||
@ -1945,13 +2001,17 @@ def precision(labels,
|
|||||||
return array_ops.where(
|
return array_ops.where(
|
||||||
math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
|
math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
|
||||||
|
|
||||||
|
def once_across_towers(_, true_p, false_p):
|
||||||
p = compute_precision(true_p, false_p, 'value')
|
p = compute_precision(true_p, false_p, 'value')
|
||||||
update_op = compute_precision(true_positives_update_op,
|
|
||||||
false_positives_update_op, 'update_op')
|
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, p)
|
ops.add_to_collections(metrics_collections, p)
|
||||||
|
return p
|
||||||
|
|
||||||
|
p = distribute_lib.get_tower_context().merge_call(
|
||||||
|
once_across_towers, true_p, false_p)
|
||||||
|
|
||||||
|
update_op = compute_precision(true_positives_update_op,
|
||||||
|
false_positives_update_op, 'update_op')
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
|
|
||||||
@ -2025,13 +2085,17 @@ def precision_at_thresholds(labels,
|
|||||||
def compute_precision(tp, fp, name):
|
def compute_precision(tp, fp, name):
|
||||||
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
|
return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
|
||||||
|
|
||||||
|
def precision_across_towers(_, values):
|
||||||
prec = compute_precision(values['tp'], values['fp'], 'value')
|
prec = compute_precision(values['tp'], values['fp'], 'value')
|
||||||
update_op = compute_precision(update_ops['tp'], update_ops['fp'],
|
|
||||||
'update_op')
|
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, prec)
|
ops.add_to_collections(metrics_collections, prec)
|
||||||
|
return prec
|
||||||
|
|
||||||
|
prec = distribute_lib.get_tower_context().merge_call(
|
||||||
|
precision_across_towers, values)
|
||||||
|
|
||||||
|
update_op = compute_precision(update_ops['tp'], update_ops['fp'],
|
||||||
|
'update_op')
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
|
|
||||||
@ -2117,13 +2181,17 @@ def recall(labels,
|
|||||||
math_ops.greater(true_p + false_n, 0),
|
math_ops.greater(true_p + false_n, 0),
|
||||||
math_ops.div(true_p, true_p + false_n), 0, name)
|
math_ops.div(true_p, true_p + false_n), 0, name)
|
||||||
|
|
||||||
|
def once_across_towers(_, true_p, false_n):
|
||||||
rec = compute_recall(true_p, false_n, 'value')
|
rec = compute_recall(true_p, false_n, 'value')
|
||||||
update_op = compute_recall(true_positives_update_op,
|
|
||||||
false_negatives_update_op, 'update_op')
|
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, rec)
|
ops.add_to_collections(metrics_collections, rec)
|
||||||
|
return rec
|
||||||
|
|
||||||
|
rec = distribute_lib.get_tower_context().merge_call(
|
||||||
|
once_across_towers, true_p, false_n)
|
||||||
|
|
||||||
|
update_op = compute_recall(true_positives_update_op,
|
||||||
|
false_negatives_update_op, 'update_op')
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
|
|
||||||
@ -2552,11 +2620,17 @@ def recall_at_top_k(labels,
|
|||||||
class_id=class_id,
|
class_id=class_id,
|
||||||
weights=weights)
|
weights=weights)
|
||||||
|
|
||||||
|
def aggregate_across_towers(_, tp, fn):
|
||||||
metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
|
metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
|
||||||
update = math_ops.div(
|
|
||||||
tp_update, math_ops.add(tp_update, fn_update), name='update')
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, metric)
|
ops.add_to_collections(metrics_collections, metric)
|
||||||
|
return metric
|
||||||
|
|
||||||
|
metric = distribute_lib.get_tower_context().merge_call(
|
||||||
|
aggregate_across_towers, tp, fn)
|
||||||
|
|
||||||
|
update = math_ops.div(
|
||||||
|
tp_update, math_ops.add(tp_update, fn_update), name='update')
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update)
|
ops.add_to_collections(updates_collections, update)
|
||||||
return metric, update
|
return metric, update
|
||||||
@ -2627,12 +2701,16 @@ def recall_at_thresholds(labels,
|
|||||||
def compute_recall(tp, fn, name):
|
def compute_recall(tp, fn, name):
|
||||||
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
|
return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
|
||||||
|
|
||||||
|
def recall_across_towers(_, values):
|
||||||
rec = compute_recall(values['tp'], values['fn'], 'value')
|
rec = compute_recall(values['tp'], values['fn'], 'value')
|
||||||
update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
|
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, rec)
|
ops.add_to_collections(metrics_collections, rec)
|
||||||
|
return rec
|
||||||
|
|
||||||
|
rec = distribute_lib.get_tower_context().merge_call(
|
||||||
|
recall_across_towers, values)
|
||||||
|
|
||||||
|
update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
|
|
||||||
@ -2698,13 +2776,16 @@ def root_mean_squared_error(labels,
|
|||||||
mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
|
mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
|
||||||
None, name or
|
None, name or
|
||||||
'root_mean_squared_error')
|
'root_mean_squared_error')
|
||||||
|
def once_across_towers(_, mse):
|
||||||
rmse = math_ops.sqrt(mse)
|
rmse = math_ops.sqrt(mse)
|
||||||
update_rmse_op = math_ops.sqrt(update_mse_op)
|
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, rmse)
|
ops.add_to_collections(metrics_collections, rmse)
|
||||||
|
return rmse
|
||||||
|
|
||||||
|
rmse = distribute_lib.get_tower_context().merge_call(
|
||||||
|
once_across_towers, mse)
|
||||||
|
|
||||||
|
update_rmse_op = math_ops.sqrt(update_mse_op)
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_rmse_op)
|
ops.add_to_collections(updates_collections, update_rmse_op)
|
||||||
|
|
||||||
@ -2797,15 +2878,19 @@ def sensitivity_at_specificity(labels,
|
|||||||
return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
|
return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
|
||||||
name)
|
name)
|
||||||
|
|
||||||
|
def aggregate_across_towers(_, values):
|
||||||
sensitivity = compute_sensitivity_at_specificity(
|
sensitivity = compute_sensitivity_at_specificity(
|
||||||
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
|
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
|
||||||
|
if metrics_collections:
|
||||||
|
ops.add_to_collections(metrics_collections, sensitivity)
|
||||||
|
return sensitivity
|
||||||
|
|
||||||
|
sensitivity = distribute_lib.get_tower_context().merge_call(
|
||||||
|
aggregate_across_towers, values)
|
||||||
|
|
||||||
update_op = compute_sensitivity_at_specificity(
|
update_op = compute_sensitivity_at_specificity(
|
||||||
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
|
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
|
||||||
'update_op')
|
'update_op')
|
||||||
|
|
||||||
if metrics_collections:
|
|
||||||
ops.add_to_collections(metrics_collections, sensitivity)
|
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
|
|
||||||
@ -3070,11 +3155,16 @@ def _streaming_sparse_average_precision_at_top_k(labels,
|
|||||||
total_update = state_ops.assign_add(total_var, batch_total, name='update')
|
total_update = state_ops.assign_add(total_var, batch_total, name='update')
|
||||||
|
|
||||||
# Divide total by max to get mean, for both vars and the update ops.
|
# Divide total by max to get mean, for both vars and the update ops.
|
||||||
|
def aggregate_across_towers(_, total_var, max_var):
|
||||||
mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
|
mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
|
||||||
update = _safe_scalar_div(total_update, max_update, name=scope)
|
|
||||||
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, mean_average_precision)
|
ops.add_to_collections(metrics_collections, mean_average_precision)
|
||||||
|
return mean_average_precision
|
||||||
|
|
||||||
|
mean_average_precision = distribute_lib.get_tower_context().merge_call(
|
||||||
|
aggregate_across_towers, total_var, max_var)
|
||||||
|
|
||||||
|
update = _safe_scalar_div(total_update, max_update, name=scope)
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update)
|
ops.add_to_collections(updates_collections, update)
|
||||||
|
|
||||||
@ -3351,11 +3441,17 @@ def precision_at_top_k(labels,
|
|||||||
class_id=class_id,
|
class_id=class_id,
|
||||||
weights=weights)
|
weights=weights)
|
||||||
|
|
||||||
|
def aggregate_across_towers(_, tp, fp):
|
||||||
metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
|
metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
|
||||||
update = math_ops.div(
|
|
||||||
tp_update, math_ops.add(tp_update, fp_update), name='update')
|
|
||||||
if metrics_collections:
|
if metrics_collections:
|
||||||
ops.add_to_collections(metrics_collections, metric)
|
ops.add_to_collections(metrics_collections, metric)
|
||||||
|
return metric
|
||||||
|
|
||||||
|
metric = distribute_lib.get_tower_context().merge_call(
|
||||||
|
aggregate_across_towers, tp, fp)
|
||||||
|
|
||||||
|
update = math_ops.div(
|
||||||
|
tp_update, math_ops.add(tp_update, fp_update), name='update')
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update)
|
ops.add_to_collections(updates_collections, update)
|
||||||
return metric, update
|
return metric, update
|
||||||
@ -3583,15 +3679,19 @@ def specificity_at_sensitivity(labels,
|
|||||||
return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
|
return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
|
||||||
name)
|
name)
|
||||||
|
|
||||||
|
def aggregate_across_towers(_, values):
|
||||||
specificity = compute_specificity_at_sensitivity(
|
specificity = compute_specificity_at_sensitivity(
|
||||||
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
|
values['tp'], values['tn'], values['fp'], values['fn'], 'value')
|
||||||
|
if metrics_collections:
|
||||||
|
ops.add_to_collections(metrics_collections, specificity)
|
||||||
|
return specificity
|
||||||
|
|
||||||
|
specificity = distribute_lib.get_tower_context().merge_call(
|
||||||
|
aggregate_across_towers, values)
|
||||||
|
|
||||||
update_op = compute_specificity_at_sensitivity(
|
update_op = compute_specificity_at_sensitivity(
|
||||||
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
|
update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
|
||||||
'update_op')
|
'update_op')
|
||||||
|
|
||||||
if metrics_collections:
|
|
||||||
ops.add_to_collections(metrics_collections, specificity)
|
|
||||||
|
|
||||||
if updates_collections:
|
if updates_collections:
|
||||||
ops.add_to_collections(updates_collections, update_op)
|
ops.add_to_collections(updates_collections, update_op)
|
||||||
|
|
||||||
|
@ -82,6 +82,54 @@ def lbeta(x, name='lbeta'):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('math.bessel_i0')
|
||||||
|
def bessel_i0(x, name='bessel_i0'):
|
||||||
|
"""Computes the Bessel i0 function of `x` element-wise.
|
||||||
|
|
||||||
|
Modified Bessel function of order 0.
|
||||||
|
|
||||||
|
It is preferable to use the numerically stabler function `i0e(x)` instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
|
||||||
|
`float32`, `float64`.
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
|
||||||
|
|
||||||
|
@compatibility(scipy)
|
||||||
|
Equivalent to scipy.special.i0
|
||||||
|
@end_compatibility
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, [x]):
|
||||||
|
return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export('math.bessel_i1')
|
||||||
|
def bessel_i1(x, name='bessel_i1'):
|
||||||
|
"""Computes the Bessel i1 function of `x` element-wise.
|
||||||
|
|
||||||
|
Modified Bessel function of order 1.
|
||||||
|
|
||||||
|
It is preferable to use the numerically stabler function `i1e(x)` instead.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
|
||||||
|
`float32`, `float64`.
|
||||||
|
name: A name for the operation (optional).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
|
||||||
|
|
||||||
|
@compatibility(scipy)
|
||||||
|
Equivalent to scipy.special.i1
|
||||||
|
@end_compatibility
|
||||||
|
"""
|
||||||
|
with ops.name_scope(name, [x]):
|
||||||
|
return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x)
|
||||||
|
|
||||||
|
|
||||||
@tf_export('einsum', 'linalg.einsum')
|
@tf_export('einsum', 'linalg.einsum')
|
||||||
def einsum(equation, *inputs, **kwargs):
|
def einsum(equation, *inputs, **kwargs):
|
||||||
"""A generalized contraction between tensors of arbitrary dimension.
|
"""A generalized contraction between tensors of arbitrary dimension.
|
||||||
|
@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
|
|||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import special_math_ops
|
from tensorflow.python.ops import special_math_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.platform import tf_logging
|
||||||
|
|
||||||
|
|
||||||
class LBetaTest(test.TestCase):
|
class LBetaTest(test.TestCase):
|
||||||
@ -150,6 +151,33 @@ class LBetaTest(test.TestCase):
|
|||||||
self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())
|
self.assertEqual(expected_result.get_shape(), lbeta_x.get_shape())
|
||||||
|
|
||||||
|
|
||||||
|
class BesselTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_bessel_i0(self):
|
||||||
|
x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
|
||||||
|
x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
|
||||||
|
try:
|
||||||
|
from scipy import special # pylint: disable=g-import-not-at-top
|
||||||
|
self.assertAllClose(special.i0(x_single),
|
||||||
|
self.evaluate(special_math_ops.bessel_i0(x_single)))
|
||||||
|
self.assertAllClose(special.i0(x_double),
|
||||||
|
self.evaluate(special_math_ops.bessel_i0(x_double)))
|
||||||
|
except ImportError as e:
|
||||||
|
tf_logging.warn('Cannot test special functions: %s' % str(e))
|
||||||
|
|
||||||
|
def test_bessel_i1(self):
|
||||||
|
x_single = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float32)
|
||||||
|
x_double = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
|
||||||
|
try:
|
||||||
|
from scipy import special # pylint: disable=g-import-not-at-top
|
||||||
|
self.assertAllClose(special.i1(x_single),
|
||||||
|
self.evaluate(special_math_ops.bessel_i1(x_single)))
|
||||||
|
self.assertAllClose(special.i1(x_double),
|
||||||
|
self.evaluate(special_math_ops.bessel_i1(x_double)))
|
||||||
|
except ImportError as e:
|
||||||
|
tf_logging.warn('Cannot test special functions: %s' % str(e))
|
||||||
|
|
||||||
|
|
||||||
class EinsumTest(test.TestCase):
|
class EinsumTest(test.TestCase):
|
||||||
|
|
||||||
simple_cases = [
|
simple_cases = [
|
||||||
|
@ -79,12 +79,14 @@ def _parse_saved_model(export_dir):
|
|||||||
constants.SAVED_MODEL_FILENAME_PB))
|
constants.SAVED_MODEL_FILENAME_PB))
|
||||||
|
|
||||||
|
|
||||||
def _get_asset_tensors(export_dir, meta_graph_def_to_load):
|
def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
|
||||||
"""Gets the asset tensors, if defined in the meta graph def to load.
|
"""Gets the asset tensors, if defined in the meta graph def to load.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
export_dir: Directory where the SavedModel is located.
|
export_dir: Directory where the SavedModel is located.
|
||||||
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
|
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
|
||||||
|
import_scope: Optional `string` -- if specified, prepend this followed by
|
||||||
|
'/' to all returned asset tensor names.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary of asset tensors, keyed by the name of the asset tensor. The
|
A dictionary of asset tensors, keyed by the name of the asset tensor. The
|
||||||
@ -104,7 +106,10 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load):
|
|||||||
for asset_any_proto in assets_any_proto:
|
for asset_any_proto in assets_any_proto:
|
||||||
asset_proto = meta_graph_pb2.AssetFileDef()
|
asset_proto = meta_graph_pb2.AssetFileDef()
|
||||||
asset_any_proto.Unpack(asset_proto)
|
asset_any_proto.Unpack(asset_proto)
|
||||||
asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
|
tensor_name = asset_proto.tensor_info.name
|
||||||
|
if import_scope:
|
||||||
|
tensor_name = "%s/%s" % (import_scope, tensor_name)
|
||||||
|
asset_tensor_dict[tensor_name] = os.path.join(
|
||||||
compat.as_bytes(assets_directory),
|
compat.as_bytes(assets_directory),
|
||||||
compat.as_bytes(asset_proto.filename))
|
compat.as_bytes(asset_proto.filename))
|
||||||
return asset_tensor_dict
|
return asset_tensor_dict
|
||||||
@ -179,7 +184,7 @@ def maybe_saved_model_directory(export_dir):
|
|||||||
|
|
||||||
|
|
||||||
@tf_export("saved_model.loader.load")
|
@tf_export("saved_model.loader.load")
|
||||||
def load(sess, tags, export_dir, **saver_kwargs):
|
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
|
||||||
"""Loads the model from a SavedModel as specified by tags.
|
"""Loads the model from a SavedModel as specified by tags.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -189,6 +194,10 @@ def load(sess, tags, export_dir, **saver_kwargs):
|
|||||||
SavedModel `save()` API.
|
SavedModel `save()` API.
|
||||||
export_dir: Directory in which the SavedModel protocol buffer and variables
|
export_dir: Directory in which the SavedModel protocol buffer and variables
|
||||||
to be loaded are located.
|
to be loaded are located.
|
||||||
|
import_scope: Optional `string` -- if specified, prepend this string
|
||||||
|
followed by '/' to all loaded tensor names. This scope is applied to
|
||||||
|
tensor instances loaded into the passed session, but it is *not* written
|
||||||
|
through to the static `MetaGraphDef` protocol buffer that is returned.
|
||||||
**saver_kwargs: Optional keyword arguments passed through to Saver.
|
**saver_kwargs: Optional keyword arguments passed through to Saver.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@ -216,7 +225,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Build a saver by importing the meta graph def to load.
|
# Build a saver by importing the meta graph def to load.
|
||||||
saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
|
saver = tf_saver.import_meta_graph(
|
||||||
|
meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs)
|
||||||
|
|
||||||
if saver:
|
if saver:
|
||||||
# Build the checkpoint path where the variables are located.
|
# Build the checkpoint path where the variables are located.
|
||||||
@ -232,8 +242,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
|
|||||||
"checkpoints were restored.")
|
"checkpoints were restored.")
|
||||||
|
|
||||||
# Get asset tensors, if any.
|
# Get asset tensors, if any.
|
||||||
asset_tensors_dictionary = _get_asset_tensors(export_dir,
|
asset_tensors_dictionary = _get_asset_tensors(
|
||||||
meta_graph_def_to_load)
|
export_dir, meta_graph_def_to_load, import_scope=import_scope)
|
||||||
|
|
||||||
main_op_tensor = (
|
main_op_tensor = (
|
||||||
_get_main_op_tensor(meta_graph_def_to_load) or
|
_get_main_op_tensor(meta_graph_def_to_load) or
|
||||||
|
@ -1197,6 +1197,59 @@ class SavedModelTest(test.TestCase):
|
|||||||
_validate_custom_saver("tag_1", "save_1/restore_all")
|
_validate_custom_saver("tag_1", "save_1/restore_all")
|
||||||
_validate_custom_saver("tag_2", "save_2/restore_all")
|
_validate_custom_saver("tag_2", "save_2/restore_all")
|
||||||
|
|
||||||
|
def testImportScope(self):
|
||||||
|
export_dir = self._get_export_dir("test_scoped_assets")
|
||||||
|
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||||
|
|
||||||
|
# Build a SavedModel with a variable, an asset, and a constant tensor.
|
||||||
|
with self.test_session(graph=ops.Graph()) as sess:
|
||||||
|
self._init_and_validate_variable(sess, "v", 42)
|
||||||
|
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
|
||||||
|
"asset_file_tensor")
|
||||||
|
constant_op.constant("constant value", name="constant_tensor_name")
|
||||||
|
builder.add_meta_graph_and_variables(
|
||||||
|
sess, ["tag_name"], assets_collection=asset_collection)
|
||||||
|
|
||||||
|
# Save the asset file path for later comparison.
|
||||||
|
asset_file_path = asset_collection[0].eval()
|
||||||
|
|
||||||
|
# Save the SavedModel to disk.
|
||||||
|
builder.save()
|
||||||
|
|
||||||
|
with self.test_session(graph=ops.Graph()) as sess:
|
||||||
|
# Restore the SavedModel under an import_scope in a new graph/session.
|
||||||
|
graph_proto = loader.load(
|
||||||
|
sess, ["tag_name"], export_dir, import_scope="scope_name")
|
||||||
|
|
||||||
|
# The loaded variable tensor should be scoped, but its contents should be
|
||||||
|
# unchanged.
|
||||||
|
self.assertEqual(
|
||||||
|
"scope_name/v:0",
|
||||||
|
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name)
|
||||||
|
self.assertEqual(
|
||||||
|
42,
|
||||||
|
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
|
||||||
|
|
||||||
|
# The loaded asset tensor should be scoped, but the asset file path and
|
||||||
|
# contents should be unchanged.
|
||||||
|
asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
|
||||||
|
self.assertEqual(1, len(asset_collection))
|
||||||
|
self.assertEqual(asset_file_path, asset_collection[0].eval())
|
||||||
|
self.assertEqual("scope_name/asset_file_tensor:0",
|
||||||
|
asset_collection[0].name)
|
||||||
|
# The static asset data inside graph_proto.collection_def should not be
|
||||||
|
# scoped.
|
||||||
|
self._validate_asset_collection(export_dir, graph_proto.collection_def,
|
||||||
|
"foo.txt", "content_foo",
|
||||||
|
"asset_file_tensor:0")
|
||||||
|
|
||||||
|
# The constant tensor should be scoped, but its contents should be
|
||||||
|
# unchanged.
|
||||||
|
self.assertEqual(
|
||||||
|
compat.as_bytes("constant value"),
|
||||||
|
ops.get_default_graph().get_tensor_by_name(
|
||||||
|
"scope_name/constant_tensor_name:0").eval())
|
||||||
|
|
||||||
def testClearDevices(self):
|
def testClearDevices(self):
|
||||||
export_dir = self._get_export_dir("test_clear_devices")
|
export_dir = self._get_export_dir("test_clear_devices")
|
||||||
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
builder = saved_model_builder.SavedModelBuilder(export_dir)
|
||||||
|
@ -528,6 +528,8 @@ class DistributionStrategy(object):
|
|||||||
* `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower
|
* `d.update_non_slot(d.non_slot_devices(), fn)`: in cross-tower
|
||||||
context, like `d.update()` except with locality N.
|
context, like `d.update()` except with locality N.
|
||||||
* `d.fetch(t)`: Copy `t` with any locality to the client's CPU device.
|
* `d.fetch(t)`: Copy `t` with any locality to the client's CPU device.
|
||||||
|
TODO(josh11b): Deprecate `fetch`, switch to `read_var` for
|
||||||
|
reading tower-local variables.
|
||||||
|
|
||||||
The standard pattern for updating variables is to:
|
The standard pattern for updating variables is to:
|
||||||
|
|
||||||
@ -614,8 +616,8 @@ class DistributionStrategy(object):
|
|||||||
|
|
||||||
There will still be one component variable per tower, but there is
|
There will still be one component variable per tower, but there is
|
||||||
no requirement that they stay in sync. Instead, when saving them
|
no requirement that they stay in sync. Instead, when saving them
|
||||||
or calling `fetch()`, we use the value that results when calling
|
or calling `fetch()/read_var()`, we use the value that
|
||||||
`reduce()` on all the towers' variables.
|
results when calling `reduce()` on all the towers' variables.
|
||||||
|
|
||||||
Note: tower-local implies not trainable. Instead, it is expected
|
Note: tower-local implies not trainable. Instead, it is expected
|
||||||
that each tower will directly update (using `assign_add()` or
|
that each tower will directly update (using `assign_add()` or
|
||||||
@ -646,6 +648,21 @@ class DistributionStrategy(object):
|
|||||||
_require_distribution_strategy_scope(self)
|
_require_distribution_strategy_scope(self)
|
||||||
return variable_scope.variable_creator_scope(create_tower_local_variable)
|
return variable_scope.variable_creator_scope(create_tower_local_variable)
|
||||||
|
|
||||||
|
def read_var(self, v):
|
||||||
|
"""Reads the value of a variable.
|
||||||
|
|
||||||
|
Returns the aggregate value of a tower-local variable, or the
|
||||||
|
(read-only) value of any other variable.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
v: A variable allocated within the scope of this `DistributionStrategy`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor representing the value of `v`, aggregated across towers if
|
||||||
|
necessary.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError("must be implemented in descendants")
|
||||||
|
|
||||||
def colocate_vars_with(self, colocate_with_variable):
|
def colocate_vars_with(self, colocate_with_variable):
|
||||||
"""Scope that controls which devices variables will be created on.
|
"""Scope that controls which devices variables will be created on.
|
||||||
|
|
||||||
@ -904,6 +921,8 @@ class DistributionStrategy(object):
|
|||||||
will attempt to avoid a copy by checking if the value is already
|
will attempt to avoid a copy by checking if the value is already
|
||||||
on the destination device.
|
on the destination device.
|
||||||
|
|
||||||
|
TODO(josh11b): Switch to `read_var`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
val: Value (which may be mirrored) to copy.
|
val: Value (which may be mirrored) to copy.
|
||||||
destination: A device string to copy the value to.
|
destination: A device string to copy the value to.
|
||||||
@ -1197,6 +1216,9 @@ class _DefaultDistributionStrategy(DistributionStrategy):
|
|||||||
with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
|
with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
|
||||||
return fn(*args, **kwargs)
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
def read_var(self, tower_local_var):
|
||||||
|
return array_ops.identity(tower_local_var)
|
||||||
|
|
||||||
def _fetch(self, var, destination, fn):
|
def _fetch(self, var, destination, fn):
|
||||||
with ops.colocate_with(var):
|
with ops.colocate_with(var):
|
||||||
var = fn(var)
|
var = fn(var)
|
||||||
|
@ -1970,7 +1970,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
|
|||||||
|
|
||||||
return Saver(saver_def=meta_graph_def.saver_def, name=scope)
|
return Saver(saver_def=meta_graph_def.saver_def, name=scope)
|
||||||
else:
|
else:
|
||||||
if variables._all_saveable_objects(): # pylint: disable=protected-access
|
if variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access
|
||||||
# Return the default saver instance for all graph variables.
|
# Return the default saver instance for all graph variables.
|
||||||
return Saver()
|
return Saver()
|
||||||
else:
|
else:
|
||||||
|
@ -2339,6 +2339,46 @@ class MetaGraphTest(test.TestCase):
|
|||||||
10, size=[1, 10])
|
10, size=[1, 10])
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def testImportIntoNamescopeWithoutVariables(self):
|
||||||
|
# Save a simple graph that contains no variables into a checkpoint.
|
||||||
|
test_dir = self._get_test_dir("no_vars_graph")
|
||||||
|
filename = os.path.join(test_dir, "ckpt")
|
||||||
|
graph_1 = ops_lib.Graph()
|
||||||
|
with session.Session(graph=graph_1) as sess:
|
||||||
|
constant_op.constant([1, 2, 3], name="x")
|
||||||
|
constant_op.constant([1, 2, 3], name="y")
|
||||||
|
saver = saver_module.Saver(allow_empty=True)
|
||||||
|
saver.save(sess, filename)
|
||||||
|
|
||||||
|
# Create a fresh graph.
|
||||||
|
graph_2 = ops_lib.Graph()
|
||||||
|
with session.Session(graph=graph_2) as sess:
|
||||||
|
# Restore the above checkpoint under scope "subgraph_1".
|
||||||
|
new_saver_1 = saver_module.import_meta_graph(
|
||||||
|
filename + ".meta", graph=graph_2, import_scope="subgraph_1")
|
||||||
|
# There are no variables to restore, so import_meta_graph should not
|
||||||
|
# return a Saver.
|
||||||
|
self.assertIsNone(new_saver_1)
|
||||||
|
|
||||||
|
# Create a variable in graph_2 under scope "my_scope".
|
||||||
|
variables.Variable(array_ops.zeros([10]), name="my_scope/my_var")
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
# Restore the checkpoint into a different scope "subgraph_2".
|
||||||
|
new_saver_2 = saver_module.import_meta_graph(
|
||||||
|
filename + ".meta", graph=graph_2, import_scope="subgraph_2")
|
||||||
|
# Because the variable does not live in scope "subgraph_2",
|
||||||
|
# import_meta_graph should not attempt to restore the variable. So,
|
||||||
|
# import_meta_graph still won't return a Saver instance.
|
||||||
|
self.assertIsNone(new_saver_2)
|
||||||
|
|
||||||
|
# However, if we restore the checkpoint under scope "my_scope",
|
||||||
|
# import_meta_graph will detect the variable and return a Saver for
|
||||||
|
# restoring it. This should happen even when the variable does not
|
||||||
|
# originate from graph_1.
|
||||||
|
new_saver_3 = saver_module.import_meta_graph(
|
||||||
|
filename + ".meta", graph=graph_2, import_scope="my_scope")
|
||||||
|
self.assertIsInstance(new_saver_3, saver_module.Saver)
|
||||||
|
|
||||||
def testImportIntoImplicitNamescope(self):
|
def testImportIntoImplicitNamescope(self):
|
||||||
# Test that we can import a meta graph into an implicit namescope.
|
# Test that we can import a meta graph into an implicit namescope.
|
||||||
test_dir = self._get_test_dir("import_into_namescope")
|
test_dir = self._get_test_dir("import_into_namescope")
|
||||||
|
@ -24,17 +24,12 @@ limitations under the License.
|
|||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <stdlib.h>
|
#include <stdlib.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
#ifdef __APPLE__
|
|
||||||
#include <IOKit/kext/KextManager.h>
|
|
||||||
#include <mach-o/dyld.h>
|
|
||||||
#else
|
|
||||||
#if !defined(PLATFORM_WINDOWS)
|
#if !defined(PLATFORM_WINDOWS)
|
||||||
#include <link.h>
|
#include <link.h>
|
||||||
#include <sys/sysmacros.h>
|
#include <sys/sysmacros.h>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
#endif
|
#endif
|
||||||
#include <sys/stat.h>
|
#include <sys/stat.h>
|
||||||
#endif
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
@ -54,9 +49,7 @@ limitations under the License.
|
|||||||
namespace stream_executor {
|
namespace stream_executor {
|
||||||
namespace cuda {
|
namespace cuda {
|
||||||
|
|
||||||
#ifdef __APPLE__
|
#if !defined(PLATFORM_WINDOWS)
|
||||||
static const CFStringRef kDriverKextIdentifier = CFSTR("com.nvidia.CUDA");
|
|
||||||
#elif !defined(PLATFORM_WINDOWS)
|
|
||||||
static const char *kDriverVersionPath = "/proc/driver/nvidia/version";
|
static const char *kDriverVersionPath = "/proc/driver/nvidia/version";
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -121,26 +114,7 @@ string Diagnostician::GetDevNodePath(int dev_node_ordinal) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Diagnostician::LogDiagnosticInformation() {
|
void Diagnostician::LogDiagnosticInformation() {
|
||||||
#ifdef __APPLE__
|
#if !defined(PLATFORM_WINDOWS)
|
||||||
CFStringRef kext_ids[1];
|
|
||||||
kext_ids[0] = kDriverKextIdentifier;
|
|
||||||
CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void**)kext_ids, 1, &kCFTypeArrayCallBacks);
|
|
||||||
CFDictionaryRef kext_infos = KextManagerCopyLoadedKextInfo(kext_id_query, nullptr);
|
|
||||||
CFRelease(kext_id_query);
|
|
||||||
|
|
||||||
CFDictionaryRef cuda_driver_info = nullptr;
|
|
||||||
if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, (const void**)&cuda_driver_info)) {
|
|
||||||
bool started = CFBooleanGetValue((CFBooleanRef)CFDictionaryGetValue(cuda_driver_info, CFSTR("OSBundleStarted")));
|
|
||||||
if (!started) {
|
|
||||||
LOG(INFO) << "kernel driver is installed, but does not appear to be running on this host "
|
|
||||||
<< "(" << port::Hostname() << ")";
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
LOG(INFO) << "kernel driver does not appear to be installed on this host "
|
|
||||||
<< "(" << port::Hostname() << ")";
|
|
||||||
}
|
|
||||||
CFRelease(kext_infos);
|
|
||||||
#elif !defined(PLATFORM_WINDOWS)
|
|
||||||
if (access(kDriverVersionPath, F_OK) != 0) {
|
if (access(kDriverVersionPath, F_OK) != 0) {
|
||||||
LOG(INFO) << "kernel driver does not appear to be running on this host "
|
LOG(INFO) << "kernel driver does not appear to be running on this host "
|
||||||
<< "(" << port::Hostname() << "): "
|
<< "(" << port::Hostname() << "): "
|
||||||
@ -194,8 +168,7 @@ void Diagnostician::LogDiagnosticInformation() {
|
|||||||
<< DriverVersionStatusToString(kernel_version);
|
<< DriverVersionStatusToString(kernel_version);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// OS X kernel driver does not report version accurately
|
#if !defined(PLATFORM_WINDOWS)
|
||||||
#if !defined(__APPLE__) && !defined(PLATFORM_WINDOWS)
|
|
||||||
if (kernel_version.ok() && dso_version.ok()) {
|
if (kernel_version.ok() && dso_version.ok()) {
|
||||||
WarnOnDsoKernelMismatch(dso_version, kernel_version);
|
WarnOnDsoKernelMismatch(dso_version, kernel_version);
|
||||||
}
|
}
|
||||||
@ -209,29 +182,6 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
|
|||||||
port::error::NOT_FOUND,
|
port::error::NOT_FOUND,
|
||||||
"was unable to find libcuda.so DSO loaded into this program"));
|
"was unable to find libcuda.so DSO loaded into this program"));
|
||||||
|
|
||||||
#if defined(__APPLE__)
|
|
||||||
// OSX CUDA libraries have names like: libcuda_310.41.15_mercury.dylib
|
|
||||||
const string prefix("libcuda_");
|
|
||||||
const string suffix("_mercury.dylib");
|
|
||||||
for (uint32_t image_index = 0; image_index < _dyld_image_count(); ++image_index) {
|
|
||||||
const string path(_dyld_get_image_name(image_index));
|
|
||||||
const size_t suffix_pos = path.rfind(suffix);
|
|
||||||
const size_t prefix_pos = path.rfind(prefix, suffix_pos);
|
|
||||||
if (prefix_pos == string::npos ||
|
|
||||||
suffix_pos == string::npos) {
|
|
||||||
// no match
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const size_t start = prefix_pos + prefix.size();
|
|
||||||
if (start >= suffix_pos) {
|
|
||||||
// version not included
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
const size_t length = suffix_pos - start;
|
|
||||||
const string version = path.substr(start, length);
|
|
||||||
result = StringToDriverVersion(version);
|
|
||||||
}
|
|
||||||
#else
|
|
||||||
#if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA)
|
#if !defined(PLATFORM_WINDOWS) && !defined(ANDROID_TEGRA)
|
||||||
// Callback used when iterating through DSOs. Looks for the driver-interfacing
|
// Callback used when iterating through DSOs. Looks for the driver-interfacing
|
||||||
// DSO and yields its version number into the callback data, when found.
|
// DSO and yields its version number into the callback data, when found.
|
||||||
@ -264,7 +214,6 @@ port::StatusOr<DriverVersion> Diagnostician::FindDsoVersion() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
dl_iterate_phdr(iterate_phdr, &result);
|
dl_iterate_phdr(iterate_phdr, &result);
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
@ -310,38 +259,7 @@ void Diagnostician::WarnOnDsoKernelMismatch(
|
|||||||
|
|
||||||
|
|
||||||
port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
|
port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() {
|
||||||
#if defined(__APPLE__)
|
#if defined(PLATFORM_WINDOWS)
|
||||||
CFStringRef kext_ids[1];
|
|
||||||
kext_ids[0] = kDriverKextIdentifier;
|
|
||||||
CFArrayRef kext_id_query = CFArrayCreate(nullptr, (const void**)kext_ids, 1, &kCFTypeArrayCallBacks);
|
|
||||||
CFDictionaryRef kext_infos = KextManagerCopyLoadedKextInfo(kext_id_query, nullptr);
|
|
||||||
CFRelease(kext_id_query);
|
|
||||||
|
|
||||||
CFDictionaryRef cuda_driver_info = nullptr;
|
|
||||||
if (CFDictionaryGetValueIfPresent(kext_infos, kDriverKextIdentifier, (const void**)&cuda_driver_info)) {
|
|
||||||
// NOTE: OSX CUDA driver does not currently store the same driver version
|
|
||||||
// in kCFBundleVersionKey as is returned by cuDriverGetVersion
|
|
||||||
CFRelease(kext_infos);
|
|
||||||
const CFStringRef str = (CFStringRef)CFDictionaryGetValue(
|
|
||||||
cuda_driver_info, kCFBundleVersionKey);
|
|
||||||
const char *version = CFStringGetCStringPtr(str, kCFStringEncodingUTF8);
|
|
||||||
|
|
||||||
// version can be NULL in which case treat it as empty string
|
|
||||||
// see
|
|
||||||
// https://developer.apple.com/library/mac/documentation/CoreFoundation/Conceptual/CFStrings/Articles/AccessingContents.html#//apple_ref/doc/uid/20001184-100980-TPXREF112
|
|
||||||
if (version == NULL) {
|
|
||||||
return StringToDriverVersion("");
|
|
||||||
}
|
|
||||||
return StringToDriverVersion(version);
|
|
||||||
}
|
|
||||||
CFRelease(kext_infos);
|
|
||||||
auto status = port::Status(
|
|
||||||
port::error::INTERNAL,
|
|
||||||
port::StrCat(
|
|
||||||
"failed to read driver bundle version: ",
|
|
||||||
CFStringGetCStringPtr(kDriverKextIdentifier, kCFStringEncodingUTF8)));
|
|
||||||
return status;
|
|
||||||
#elif defined(PLATFORM_WINDOWS)
|
|
||||||
auto status =
|
auto status =
|
||||||
port::Status(port::error::UNIMPLEMENTED,
|
port::Status(port::error::UNIMPLEMENTED,
|
||||||
"kernel reported driver version not implemented on Windows");
|
"kernel reported driver version not implemented on Windows");
|
||||||
|
@ -495,9 +495,9 @@ PersistentRnnPlan CreatePersistentRnnPlan(cudnnRNNDescriptor_t rnn_desc,
|
|||||||
|
|
||||||
// Turns a BatchDescriptor structure into a cudnn tensor handle within a
|
// Turns a BatchDescriptor structure into a cudnn tensor handle within a
|
||||||
// scope.
|
// scope.
|
||||||
class ScopedTensorDescriptor {
|
class CudnnTensorDescriptor {
|
||||||
public:
|
public:
|
||||||
ScopedTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
|
CudnnTensorDescriptor(const dnn::BatchDescriptor& batch_descriptor,
|
||||||
cudnnDataType_t elem_type)
|
cudnnDataType_t elem_type)
|
||||||
: handle_(CreateTensorDescriptor()) {
|
: handle_(CreateTensorDescriptor()) {
|
||||||
switch (batch_descriptor.layout()) {
|
switch (batch_descriptor.layout()) {
|
||||||
@ -540,14 +540,14 @@ class ScopedTensorDescriptor {
|
|||||||
private:
|
private:
|
||||||
TensorDescriptor handle_;
|
TensorDescriptor handle_;
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(ScopedTensorDescriptor);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnTensorDescriptor);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Turns a FilterDescriptor structure into a cudnn filter handle within a
|
// Turns a FilterDescriptor structure into a cudnn filter handle within a
|
||||||
// scope.
|
// scope.
|
||||||
class ScopedFilterDescriptor {
|
class CudnnFilterDescriptor {
|
||||||
public:
|
public:
|
||||||
ScopedFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
|
CudnnFilterDescriptor(const dnn::FilterDescriptor& filter_descriptor,
|
||||||
cudnnDataType_t elem_type)
|
cudnnDataType_t elem_type)
|
||||||
: handle_(CreateFilterDescriptor()) {
|
: handle_(CreateFilterDescriptor()) {
|
||||||
// TODO(b/23032134): Even if the filter layout is not supported,
|
// TODO(b/23032134): Even if the filter layout is not supported,
|
||||||
@ -586,7 +586,7 @@ class ScopedFilterDescriptor {
|
|||||||
private:
|
private:
|
||||||
FilterDescriptor handle_; // Owned.
|
FilterDescriptor handle_; // Owned.
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(ScopedFilterDescriptor);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnFilterDescriptor);
|
||||||
};
|
};
|
||||||
|
|
||||||
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
|
// A helper function to decide whether to enable the TENSOR_OP_MATH math type
|
||||||
@ -636,9 +636,9 @@ bool BatchnormSpatialPersistentEnabled() {
|
|||||||
|
|
||||||
// Turns a ConvolutionDescriptor structure into a cudnn convolution handle
|
// Turns a ConvolutionDescriptor structure into a cudnn convolution handle
|
||||||
// within a scope.
|
// within a scope.
|
||||||
class ScopedConvolutionDescriptor {
|
class CudnnConvolutionDescriptor {
|
||||||
public:
|
public:
|
||||||
ScopedConvolutionDescriptor(
|
CudnnConvolutionDescriptor(
|
||||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||||
cudnnDataType_t data_type)
|
cudnnDataType_t data_type)
|
||||||
: handle_(CreateConvolutionDescriptor()) {
|
: handle_(CreateConvolutionDescriptor()) {
|
||||||
@ -700,14 +700,14 @@ class ScopedConvolutionDescriptor {
|
|||||||
private:
|
private:
|
||||||
ConvolutionDescriptor handle_; // Owned.
|
ConvolutionDescriptor handle_; // Owned.
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(ScopedConvolutionDescriptor);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnConvolutionDescriptor);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
|
// Turns a PoolingDescriptor structure into a cudnn pooling descriptor handle
|
||||||
// within a scope.
|
// within a scope.
|
||||||
class ScopedPoolingDescriptor {
|
class CudnnPoolingDescriptor {
|
||||||
public:
|
public:
|
||||||
explicit ScopedPoolingDescriptor(
|
explicit CudnnPoolingDescriptor(
|
||||||
const dnn::PoolingDescriptor& pooling_descriptor)
|
const dnn::PoolingDescriptor& pooling_descriptor)
|
||||||
: handle_(CreatePoolingDescriptor()) {
|
: handle_(CreatePoolingDescriptor()) {
|
||||||
const std::vector<int64> strides64 = pooling_descriptor.strides();
|
const std::vector<int64> strides64 = pooling_descriptor.strides();
|
||||||
@ -739,13 +739,13 @@ class ScopedPoolingDescriptor {
|
|||||||
private:
|
private:
|
||||||
PoolingDescriptor handle_; // Owned.
|
PoolingDescriptor handle_; // Owned.
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(ScopedPoolingDescriptor);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnPoolingDescriptor);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
|
// Turns a NormalizeDescriptor structure into a cudnn LRN descriptor handle.
|
||||||
class ScopedNormalizeDescriptor {
|
class CudnnNormalizeDescriptor {
|
||||||
public:
|
public:
|
||||||
explicit ScopedNormalizeDescriptor(
|
explicit CudnnNormalizeDescriptor(
|
||||||
const dnn::NormalizeDescriptor& normalize_descriptor)
|
const dnn::NormalizeDescriptor& normalize_descriptor)
|
||||||
: handle_(CreateLrnDescriptor()) {
|
: handle_(CreateLrnDescriptor()) {
|
||||||
// The range specifies that the indices in the closed range
|
// The range specifies that the indices in the closed range
|
||||||
@ -777,14 +777,14 @@ class ScopedNormalizeDescriptor {
|
|||||||
private:
|
private:
|
||||||
LrnDescriptor handle_; // Owned.
|
LrnDescriptor handle_; // Owned.
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(ScopedNormalizeDescriptor);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnNormalizeDescriptor);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Turns a ActivationDescriptor structure into a cudnn activation
|
// Turns a ActivationDescriptor structure into a cudnn activation
|
||||||
// descriptor handle within a scope.
|
// descriptor handle within a scope.
|
||||||
class ScopedActivationDescriptor {
|
class CudnnActivationDescriptor {
|
||||||
public:
|
public:
|
||||||
ScopedActivationDescriptor(dnn::ActivationMode activation_mode,
|
CudnnActivationDescriptor(dnn::ActivationMode activation_mode,
|
||||||
cudnnNanPropagation_t nan_propagation,
|
cudnnNanPropagation_t nan_propagation,
|
||||||
double value_max)
|
double value_max)
|
||||||
: handle_(CreateActivationDescriptor()) {
|
: handle_(CreateActivationDescriptor()) {
|
||||||
@ -822,7 +822,7 @@ class ScopedActivationDescriptor {
|
|||||||
private:
|
private:
|
||||||
ActivationDescriptor handle_; // Owned.
|
ActivationDescriptor handle_; // Owned.
|
||||||
|
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(ScopedActivationDescriptor);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnActivationDescriptor);
|
||||||
};
|
};
|
||||||
|
|
||||||
cudnnDataType_t ToCudnnDataType(
|
cudnnDataType_t ToCudnnDataType(
|
||||||
@ -888,21 +888,21 @@ int CudnnDataTypeToByteSize(cudnnDataType_t data_type) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class ScopedDropoutDescriptor {
|
class CudnnDropoutDescriptor {
|
||||||
explicit ScopedDropoutDescriptor(DropoutDescriptor handle)
|
explicit CudnnDropoutDescriptor(DropoutDescriptor handle)
|
||||||
: handle_(std::move(handle)) {}
|
: handle_(std::move(handle)) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
ScopedDropoutDescriptor(ScopedDropoutDescriptor&&) = default;
|
CudnnDropoutDescriptor(CudnnDropoutDescriptor&&) = default;
|
||||||
|
|
||||||
static port::StatusOr<ScopedDropoutDescriptor> Create(
|
static port::StatusOr<CudnnDropoutDescriptor> Create(
|
||||||
const CudnnHandle& cudnn, float dropout, uint64 seed,
|
const CudnnHandle& cudnn, float dropout, uint64 seed,
|
||||||
ScratchAllocator* state_allocator) {
|
ScratchAllocator* state_allocator) {
|
||||||
DropoutDescriptor handle = CreateDropoutDescriptor();
|
DropoutDescriptor handle = CreateDropoutDescriptor();
|
||||||
|
|
||||||
if (dropout == 0.0f) {
|
if (dropout == 0.0f) {
|
||||||
// Return 'empty' dropout descriptor.
|
// Return 'empty' dropout descriptor.
|
||||||
return ScopedDropoutDescriptor(std::move(handle));
|
return CudnnDropoutDescriptor(std::move(handle));
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceMemory<uint8> state_memory;
|
DeviceMemory<uint8> state_memory;
|
||||||
@ -917,14 +917,14 @@ class ScopedDropoutDescriptor {
|
|||||||
handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
|
handle.get(), cudnn.handle(), dropout, state_memory.opaque(),
|
||||||
state_memory.size(), seed));
|
state_memory.size(), seed));
|
||||||
|
|
||||||
return ScopedDropoutDescriptor(std::move(handle));
|
return CudnnDropoutDescriptor(std::move(handle));
|
||||||
}
|
}
|
||||||
|
|
||||||
cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
|
cudnnDropoutDescriptor_t handle() const { return handle_.get(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
DropoutDescriptor handle_; // Owned.
|
DropoutDescriptor handle_; // Owned.
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(ScopedDropoutDescriptor);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnDropoutDescriptor);
|
||||||
};
|
};
|
||||||
|
|
||||||
class CudnnRnnParamsDescriptor {
|
class CudnnRnnParamsDescriptor {
|
||||||
@ -973,7 +973,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
|||||||
cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
|
cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
|
||||||
cudnnDataType_t compute_type,
|
cudnnDataType_t compute_type,
|
||||||
const dnn::AlgorithmConfig& algorithm_config,
|
const dnn::AlgorithmConfig& algorithm_config,
|
||||||
ScopedDropoutDescriptor dropout_desc,
|
CudnnDropoutDescriptor dropout_desc,
|
||||||
CudnnRnnParamsDescriptor params_desc)
|
CudnnRnnParamsDescriptor params_desc)
|
||||||
: rnn_desc_(std::move(rnn_desc)),
|
: rnn_desc_(std::move(rnn_desc)),
|
||||||
rnn_plan_(std::move(rnn_plan)),
|
rnn_plan_(std::move(rnn_plan)),
|
||||||
@ -1002,8 +1002,8 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
|||||||
const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
|
const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
|
||||||
ScratchAllocator* state_allocator) {
|
ScratchAllocator* state_allocator) {
|
||||||
SE_ASSIGN_OR_RETURN(
|
SE_ASSIGN_OR_RETURN(
|
||||||
ScopedDropoutDescriptor dropout_desc,
|
CudnnDropoutDescriptor dropout_desc,
|
||||||
ScopedDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
|
CudnnDropoutDescriptor::Create(cudnn, dropout, seed, state_allocator));
|
||||||
|
|
||||||
cuda::RnnDescriptor rnn_desc = CreateRnnDescriptor();
|
cuda::RnnDescriptor rnn_desc = CreateRnnDescriptor();
|
||||||
cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
|
cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config.algorithm());
|
||||||
@ -1097,7 +1097,7 @@ class CudnnRnnDescriptor : public dnn::RnnDescriptor {
|
|||||||
cudnnDataType_t data_type_;
|
cudnnDataType_t data_type_;
|
||||||
cudnnDataType_t compute_type_;
|
cudnnDataType_t compute_type_;
|
||||||
dnn::AlgorithmConfig algorithm_config_;
|
dnn::AlgorithmConfig algorithm_config_;
|
||||||
ScopedDropoutDescriptor dropout_desc_;
|
CudnnDropoutDescriptor dropout_desc_;
|
||||||
CudnnRnnParamsDescriptor params_desc_;
|
CudnnRnnParamsDescriptor params_desc_;
|
||||||
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
|
SE_DISALLOW_COPY_AND_ASSIGN(CudnnRnnDescriptor);
|
||||||
};
|
};
|
||||||
@ -1926,10 +1926,9 @@ namespace {
|
|||||||
// and backward filter.
|
// and backward filter.
|
||||||
|
|
||||||
port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
|
port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
|
||||||
const CudnnHandle& cudnn, const ScopedTensorDescriptor& input_nd,
|
const CudnnHandle& cudnn, const CudnnTensorDescriptor& input_nd,
|
||||||
const ScopedFilterDescriptor& filter,
|
const CudnnFilterDescriptor& filter, const CudnnConvolutionDescriptor& conv,
|
||||||
const ScopedConvolutionDescriptor& conv,
|
const CudnnTensorDescriptor& output_nd, bool specify_workspace_limit,
|
||||||
const ScopedTensorDescriptor& output_nd, bool specify_workspace_limit,
|
|
||||||
size_t memory_limit_bytes) {
|
size_t memory_limit_bytes) {
|
||||||
cudnnConvolutionFwdPreference_t preference =
|
cudnnConvolutionFwdPreference_t preference =
|
||||||
specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
|
specify_workspace_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
|
||||||
@ -1943,10 +1942,10 @@ port::StatusOr<cudnnConvolutionFwdAlgo_t> GetCudnnConvolutionForwardAlgo(
|
|||||||
|
|
||||||
port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
|
port::StatusOr<cudnnConvolutionBwdDataAlgo_t>
|
||||||
GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
|
GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
|
||||||
const ScopedTensorDescriptor& input_nd,
|
const CudnnTensorDescriptor& input_nd,
|
||||||
const ScopedFilterDescriptor& filter,
|
const CudnnFilterDescriptor& filter,
|
||||||
const ScopedConvolutionDescriptor& conv,
|
const CudnnConvolutionDescriptor& conv,
|
||||||
const ScopedTensorDescriptor& output_nd,
|
const CudnnTensorDescriptor& output_nd,
|
||||||
bool specify_workspace_limit,
|
bool specify_workspace_limit,
|
||||||
size_t memory_limit_bytes) {
|
size_t memory_limit_bytes) {
|
||||||
cudnnConvolutionBwdDataPreference_t preference =
|
cudnnConvolutionBwdDataPreference_t preference =
|
||||||
@ -1962,10 +1961,10 @@ GetCudnnConvolutionBackwardDataAlgo(const CudnnHandle& cudnn,
|
|||||||
|
|
||||||
port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
|
port::StatusOr<cudnnConvolutionBwdFilterAlgo_t>
|
||||||
GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
|
GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
|
||||||
const ScopedTensorDescriptor& input_nd,
|
const CudnnTensorDescriptor& input_nd,
|
||||||
const ScopedFilterDescriptor& filter,
|
const CudnnFilterDescriptor& filter,
|
||||||
const ScopedConvolutionDescriptor& conv,
|
const CudnnConvolutionDescriptor& conv,
|
||||||
const ScopedTensorDescriptor& output_nd,
|
const CudnnTensorDescriptor& output_nd,
|
||||||
bool specify_workspace_limit,
|
bool specify_workspace_limit,
|
||||||
size_t memory_limit_bytes) {
|
size_t memory_limit_bytes) {
|
||||||
cudnnConvolutionBwdFilterPreference_t preference =
|
cudnnConvolutionBwdFilterPreference_t preference =
|
||||||
@ -1982,10 +1981,9 @@ GetCudnnConvolutionBackwardFilterAlgo(const CudnnHandle& cudnn,
|
|||||||
port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
|
port::StatusOr<DeviceMemory<uint8>> AllocateCudnnConvolutionForwardWorkspace(
|
||||||
Stream* stream, const CudnnHandle& cudnn,
|
Stream* stream, const CudnnHandle& cudnn,
|
||||||
const dnn::AlgorithmDesc& algorithm_desc,
|
const dnn::AlgorithmDesc& algorithm_desc,
|
||||||
const ScopedTensorDescriptor& input_nd,
|
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
|
||||||
const ScopedFilterDescriptor& filter,
|
const CudnnConvolutionDescriptor& conv,
|
||||||
const ScopedConvolutionDescriptor& conv,
|
const CudnnTensorDescriptor& output_nd,
|
||||||
const ScopedTensorDescriptor& output_nd,
|
|
||||||
ScratchAllocator* scratch_allocator) {
|
ScratchAllocator* scratch_allocator) {
|
||||||
// TODO(csigg): This has side effects on the convolution descriptor. It is
|
// TODO(csigg): This has side effects on the convolution descriptor. It is
|
||||||
// functionally correct because the convolution is run with the algorithm of
|
// functionally correct because the convolution is run with the algorithm of
|
||||||
@ -2025,10 +2023,9 @@ port::StatusOr<DeviceMemory<uint8>>
|
|||||||
AllocateCudnnConvolutionBackwardDataWorkspace(
|
AllocateCudnnConvolutionBackwardDataWorkspace(
|
||||||
Stream* stream, const CudnnHandle& cudnn,
|
Stream* stream, const CudnnHandle& cudnn,
|
||||||
const dnn::AlgorithmDesc& algorithm_desc,
|
const dnn::AlgorithmDesc& algorithm_desc,
|
||||||
const ScopedTensorDescriptor& input_nd,
|
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
|
||||||
const ScopedFilterDescriptor& filter,
|
const CudnnConvolutionDescriptor& conv,
|
||||||
const ScopedConvolutionDescriptor& conv,
|
const CudnnTensorDescriptor& output_nd,
|
||||||
const ScopedTensorDescriptor& output_nd,
|
|
||||||
ScratchAllocator* scratch_allocator) {
|
ScratchAllocator* scratch_allocator) {
|
||||||
// TODO(csigg): This has side effects on the convolution descriptor. It is
|
// TODO(csigg): This has side effects on the convolution descriptor. It is
|
||||||
// functionally correct because the convolution is run with the algorithm of
|
// functionally correct because the convolution is run with the algorithm of
|
||||||
@ -2070,10 +2067,9 @@ port::StatusOr<DeviceMemory<uint8>>
|
|||||||
AllocateCudnnConvolutionBackwardFilterWorkspace(
|
AllocateCudnnConvolutionBackwardFilterWorkspace(
|
||||||
Stream* stream, const CudnnHandle& cudnn,
|
Stream* stream, const CudnnHandle& cudnn,
|
||||||
const dnn::AlgorithmDesc& algorithm_desc,
|
const dnn::AlgorithmDesc& algorithm_desc,
|
||||||
const ScopedTensorDescriptor& input_nd,
|
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
|
||||||
const ScopedFilterDescriptor& filter,
|
const CudnnConvolutionDescriptor& conv,
|
||||||
const ScopedConvolutionDescriptor& conv,
|
const CudnnTensorDescriptor& output_nd,
|
||||||
const ScopedTensorDescriptor& output_nd,
|
|
||||||
ScratchAllocator* scratch_allocator) {
|
ScratchAllocator* scratch_allocator) {
|
||||||
// TODO(csigg): This has side effects on the convolution descriptor. It is
|
// TODO(csigg): This has side effects on the convolution descriptor. It is
|
||||||
// functionally correct because the convolution is run with the algorithm of
|
// functionally correct because the convolution is run with the algorithm of
|
||||||
@ -2114,11 +2110,10 @@ AllocateCudnnConvolutionBackwardFilterWorkspace(
|
|||||||
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
|
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
|
||||||
Stream* stream, const CudnnHandle& cudnn,
|
Stream* stream, const CudnnHandle& cudnn,
|
||||||
const dnn::AlgorithmConfig& algorithm_config,
|
const dnn::AlgorithmConfig& algorithm_config,
|
||||||
const ScopedTensorDescriptor& input_nd,
|
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
|
||||||
const ScopedFilterDescriptor& filter,
|
const CudnnConvolutionDescriptor& conv,
|
||||||
const ScopedConvolutionDescriptor& conv,
|
const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
|
||||||
const ScopedTensorDescriptor& output_nd,
|
DeviceMemory<uint8>* scratch) {
|
||||||
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
|
|
||||||
dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
|
dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
|
||||||
if (algorithm_config.algorithm().is_default()) {
|
if (algorithm_config.algorithm().is_default()) {
|
||||||
// Pick fastest algorithm within memory limit according to cuDNN's
|
// Pick fastest algorithm within memory limit according to cuDNN's
|
||||||
@ -2164,11 +2159,10 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionForwardAlgorithm(
|
|||||||
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
|
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
|
||||||
Stream* stream, const CudnnHandle& cudnn,
|
Stream* stream, const CudnnHandle& cudnn,
|
||||||
const dnn::AlgorithmConfig& algorithm_config,
|
const dnn::AlgorithmConfig& algorithm_config,
|
||||||
const ScopedTensorDescriptor& input_nd,
|
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
|
||||||
const ScopedFilterDescriptor& filter,
|
const CudnnConvolutionDescriptor& conv,
|
||||||
const ScopedConvolutionDescriptor& conv,
|
const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
|
||||||
const ScopedTensorDescriptor& output_nd,
|
DeviceMemory<uint8>* scratch) {
|
||||||
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
|
|
||||||
dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
|
dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
|
||||||
if (algorithm_config.algorithm().is_default()) {
|
if (algorithm_config.algorithm().is_default()) {
|
||||||
// Pick fastest algorithm within memory limit according to cuDNN's
|
// Pick fastest algorithm within memory limit according to cuDNN's
|
||||||
@ -2214,11 +2208,10 @@ port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardDataAlgorithm(
|
|||||||
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
|
port::StatusOr<dnn::AlgorithmDesc> GetCudnnConvolutionBackwardFilterAlgorithm(
|
||||||
Stream* stream, const CudnnHandle& cudnn,
|
Stream* stream, const CudnnHandle& cudnn,
|
||||||
const dnn::AlgorithmConfig& algorithm_config,
|
const dnn::AlgorithmConfig& algorithm_config,
|
||||||
const ScopedTensorDescriptor& input_nd,
|
const CudnnTensorDescriptor& input_nd, const CudnnFilterDescriptor& filter,
|
||||||
const ScopedFilterDescriptor& filter,
|
const CudnnConvolutionDescriptor& conv,
|
||||||
const ScopedConvolutionDescriptor& conv,
|
const CudnnTensorDescriptor& output_nd, ScratchAllocator* scratch_allocator,
|
||||||
const ScopedTensorDescriptor& output_nd,
|
DeviceMemory<uint8>* scratch) {
|
||||||
ScratchAllocator* scratch_allocator, DeviceMemory<uint8>* scratch) {
|
|
||||||
dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
|
dnn::AlgorithmDesc algo_desc = algorithm_config.algorithm();
|
||||||
if (algorithm_config.algorithm().is_default()) {
|
if (algorithm_config.algorithm().is_default()) {
|
||||||
// Pick fastest algorithm within memory limit according to cuDNN's
|
// Pick fastest algorithm within memory limit according to cuDNN's
|
||||||
@ -2387,10 +2380,10 @@ port::Status CudnnSupport::DoConvolveImpl(
|
|||||||
const dnn::AlgorithmConfig& algorithm_config,
|
const dnn::AlgorithmConfig& algorithm_config,
|
||||||
dnn::ProfileResult* output_profile_result) {
|
dnn::ProfileResult* output_profile_result) {
|
||||||
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
|
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
|
||||||
ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
|
CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
|
||||||
ScopedTensorDescriptor output_nd(output_descriptor, cudnn_type);
|
CudnnTensorDescriptor output_nd(output_descriptor, cudnn_type);
|
||||||
ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
|
CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
|
||||||
ScopedConvolutionDescriptor conv(convolution_descriptor,
|
CudnnConvolutionDescriptor conv(convolution_descriptor,
|
||||||
GetConvComputeType<T>());
|
GetConvComputeType<T>());
|
||||||
|
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
@ -2493,14 +2486,14 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
|
|||||||
"Relu activation.");
|
"Relu activation.");
|
||||||
}
|
}
|
||||||
|
|
||||||
ScopedTensorDescriptor conv_input_nd(
|
CudnnTensorDescriptor conv_input_nd(
|
||||||
conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
|
conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
|
||||||
ScopedTensorDescriptor output_nd(
|
CudnnTensorDescriptor output_nd(
|
||||||
output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
|
output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
|
||||||
ScopedFilterDescriptor filter(filter_descriptor,
|
CudnnFilterDescriptor filter(filter_descriptor,
|
||||||
static_cast<cudnnDataType_t>(cudnn_data_type));
|
static_cast<cudnnDataType_t>(cudnn_data_type));
|
||||||
ScopedTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
|
||||||
ScopedConvolutionDescriptor conv(
|
CudnnConvolutionDescriptor conv(
|
||||||
convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
|
convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
|
||||||
|
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
@ -2528,7 +2521,7 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
|
|||||||
// activation descriptor. Note that this will change the nan propagation
|
// activation descriptor. Note that this will change the nan propagation
|
||||||
// behavior from separate conv, bias, and relu (which by default is
|
// behavior from separate conv, bias, and relu (which by default is
|
||||||
// CUDNN_PROPAGATE_NAN.
|
// CUDNN_PROPAGATE_NAN.
|
||||||
ScopedActivationDescriptor activation_desc(
|
CudnnActivationDescriptor activation_desc(
|
||||||
activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
|
activation_mode, CUDNN_NOT_PROPAGATE_NAN, output_descriptor.value_max());
|
||||||
auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
|
auto side_input_data_ptr = (side_input_scale == 0) ? output_data->opaque()
|
||||||
: side_input_data.opaque();
|
: side_input_data.opaque();
|
||||||
@ -2740,8 +2733,8 @@ port::Status CudnnSupport::DoBatchNormalizationForwardImpl(
|
|||||||
DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
|
DeviceMemory<U>* saved_mean, DeviceMemory<U>* saved_inv_var,
|
||||||
bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
|
bool is_training, std::function<const DeviceMemory<U>&()> var_to_inv_var,
|
||||||
std::function<void()> inv_var_to_var) {
|
std::function<void()> inv_var_to_var) {
|
||||||
ScopedTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
|
CudnnTensorDescriptor x_descriptor(x_desc, ToCudnnDataType(input_data_type));
|
||||||
ScopedTensorDescriptor scale_offset_descriptor(
|
CudnnTensorDescriptor scale_offset_descriptor(
|
||||||
scale_offset_desc, ToCudnnDataType(scale_data_type));
|
scale_offset_desc, ToCudnnDataType(scale_data_type));
|
||||||
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
|
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
|
||||||
#if CUDNN_VERSION >= 7000
|
#if CUDNN_VERSION >= 7000
|
||||||
@ -2825,9 +2818,9 @@ port::Status CudnnSupport::DoBatchNormalizationBackwardImpl(
|
|||||||
const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
|
const dnn::BatchDescriptor& scale_offset_desc, const double epsilon,
|
||||||
DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
|
DeviceMemory<T>* x_backprop, DeviceMemory<U>* scale_backprop,
|
||||||
DeviceMemory<U>* offset_backprop) {
|
DeviceMemory<U>* offset_backprop) {
|
||||||
ScopedTensorDescriptor x_descriptor(
|
CudnnTensorDescriptor x_descriptor(
|
||||||
x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
|
x_desc, static_cast<cudnnDataType_t>(cudnn_input_type));
|
||||||
ScopedTensorDescriptor scale_offset_descriptor(
|
CudnnTensorDescriptor scale_offset_descriptor(
|
||||||
scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
|
scale_offset_desc, static_cast<cudnnDataType_t>(cudnn_scale_type));
|
||||||
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
|
cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL;
|
||||||
#if CUDNN_VERSION >= 7000
|
#if CUDNN_VERSION >= 7000
|
||||||
@ -3017,9 +3010,9 @@ bool CudnnSupport::DoTransformTensor(Stream* stream,
|
|||||||
dnn::DataType output_type, float scale,
|
dnn::DataType output_type, float scale,
|
||||||
DeviceMemoryBase* output_data) {
|
DeviceMemoryBase* output_data) {
|
||||||
float beta = 0.0f;
|
float beta = 0.0f;
|
||||||
ScopedTensorDescriptor input_tensor_desc(
|
CudnnTensorDescriptor input_tensor_desc(
|
||||||
input_desc, ToCudnnDataType(input_type, input_desc.layout()));
|
input_desc, ToCudnnDataType(input_type, input_desc.layout()));
|
||||||
ScopedTensorDescriptor output_tensor_desc(
|
CudnnTensorDescriptor output_tensor_desc(
|
||||||
output_desc, ToCudnnDataType(output_type, output_desc.layout()));
|
output_desc, ToCudnnDataType(output_type, output_desc.layout()));
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
auto status = [&] {
|
auto status = [&] {
|
||||||
@ -3056,10 +3049,10 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
|
|||||||
|
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
|
|
||||||
ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
|
CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
|
||||||
ScopedTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
|
CudnnTensorDescriptor in_back_nd(input_descriptor, cudnn_type);
|
||||||
ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
|
CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
|
||||||
ScopedConvolutionDescriptor conv(convolution_descriptor,
|
CudnnConvolutionDescriptor conv(convolution_descriptor,
|
||||||
GetConvComputeType<T>());
|
GetConvComputeType<T>());
|
||||||
|
|
||||||
const bool is_profiling = output_profile_result != nullptr;
|
const bool is_profiling = output_profile_result != nullptr;
|
||||||
@ -3192,10 +3185,10 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
|
|||||||
|
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
|
|
||||||
ScopedTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
|
CudnnTensorDescriptor out_back_nd(output_descriptor, cudnn_type);
|
||||||
ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
|
CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
|
||||||
ScopedFilterDescriptor filter(filter_descriptor, cudnn_type);
|
CudnnFilterDescriptor filter(filter_descriptor, cudnn_type);
|
||||||
ScopedConvolutionDescriptor conv(convolution_descriptor,
|
CudnnConvolutionDescriptor conv(convolution_descriptor,
|
||||||
GetConvComputeType<T>());
|
GetConvComputeType<T>());
|
||||||
|
|
||||||
const bool is_profiling = output_profile_result != nullptr;
|
const bool is_profiling = output_profile_result != nullptr;
|
||||||
@ -3338,8 +3331,8 @@ port::Status CudnnSupport::DoConvolveBackwardBiasImpl(
|
|||||||
const dnn::BatchDescriptor& bias_descriptor,
|
const dnn::BatchDescriptor& bias_descriptor,
|
||||||
DeviceMemory<T>* backward_bias_data) {
|
DeviceMemory<T>* backward_bias_data) {
|
||||||
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
|
cudnnDataType_t cudnn_type = GetCudnnDataType<T>();
|
||||||
ScopedTensorDescriptor input_nd(input_descriptor, cudnn_type);
|
CudnnTensorDescriptor input_nd(input_descriptor, cudnn_type);
|
||||||
ScopedTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
|
CudnnTensorDescriptor bias_nd(bias_descriptor, cudnn_type);
|
||||||
|
|
||||||
// Alpha is the scaling factor for input.
|
// Alpha is the scaling factor for input.
|
||||||
float alpha = 1.0;
|
float alpha = 1.0;
|
||||||
@ -3526,7 +3519,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
|
|||||||
const DeviceMemory<float>& biases,
|
const DeviceMemory<float>& biases,
|
||||||
const dnn::BatchDescriptor& dimensions,
|
const dnn::BatchDescriptor& dimensions,
|
||||||
DeviceMemory<float>* output_data) {
|
DeviceMemory<float>* output_data) {
|
||||||
ScopedTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor input_descriptor(dimensions, CUDNN_DATA_FLOAT);
|
||||||
|
|
||||||
dnn::BatchDescriptor bias_dimensions;
|
dnn::BatchDescriptor bias_dimensions;
|
||||||
bias_dimensions.set_count(1)
|
bias_dimensions.set_count(1)
|
||||||
@ -3534,7 +3527,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
|
|||||||
.set_height(1)
|
.set_height(1)
|
||||||
.set_width(1)
|
.set_width(1)
|
||||||
.set_layout(dnn::DataLayout::kBatchYXDepth);
|
.set_layout(dnn::DataLayout::kBatchYXDepth);
|
||||||
ScopedTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor bias_descriptor(bias_dimensions, CUDNN_DATA_FLOAT);
|
||||||
|
|
||||||
// cudnnAddTensor after R3 is in-place, so we need to copy input_data to
|
// cudnnAddTensor after R3 is in-place, so we need to copy input_data to
|
||||||
// output_data before doing the addition, unless the input and
|
// output_data before doing the addition, unless the input and
|
||||||
@ -3570,10 +3563,10 @@ bool CudnnSupport::DoActivate(Stream* stream,
|
|||||||
const DeviceMemory<float>& input_data,
|
const DeviceMemory<float>& input_data,
|
||||||
DeviceMemory<float>* output_data,
|
DeviceMemory<float>* output_data,
|
||||||
uint64 options) {
|
uint64 options) {
|
||||||
ScopedActivationDescriptor activation_desc(
|
CudnnActivationDescriptor activation_desc(
|
||||||
activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
|
activation_mode, CUDNN_PROPAGATE_NAN, dimensions.value_max());
|
||||||
|
|
||||||
ScopedTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor input_nd(dimensions, CUDNN_DATA_FLOAT);
|
||||||
// Alpha is the input scaling factor.
|
// Alpha is the input scaling factor.
|
||||||
float alpha = 1.0;
|
float alpha = 1.0;
|
||||||
// Beta is the output scaling factor.
|
// Beta is the output scaling factor.
|
||||||
@ -3600,9 +3593,9 @@ bool CudnnSupport::DoPoolForward(
|
|||||||
// Beta is the scaling factor for output.
|
// Beta is the scaling factor for output.
|
||||||
double beta = 0.0;
|
double beta = 0.0;
|
||||||
|
|
||||||
ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
|
CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
|
||||||
ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
|
CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
|
||||||
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
|
CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
|
||||||
|
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
auto status = [&] {
|
auto status = [&] {
|
||||||
@ -3625,9 +3618,9 @@ bool CudnnSupport::DoPoolForward(
|
|||||||
// Beta is the scaling factor for output.
|
// Beta is the scaling factor for output.
|
||||||
float beta = 0.0;
|
float beta = 0.0;
|
||||||
|
|
||||||
ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
|
||||||
ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
|
||||||
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
|
CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
|
||||||
|
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
auto status = [&] {
|
auto status = [&] {
|
||||||
@ -3650,9 +3643,9 @@ bool CudnnSupport::DoPoolForward(
|
|||||||
// Beta is the scaling factor for output.
|
// Beta is the scaling factor for output.
|
||||||
float beta = 0.0;
|
float beta = 0.0;
|
||||||
|
|
||||||
ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
|
CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
|
||||||
ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
|
CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
|
||||||
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
|
CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
auto status = [&] {
|
auto status = [&] {
|
||||||
RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
|
RETURN_IF_CUDNN_ERROR(cudnnPoolingForward(
|
||||||
@ -3676,9 +3669,9 @@ bool CudnnSupport::DoPoolBackward(
|
|||||||
// Beta is the scaling factor for output.
|
// Beta is the scaling factor for output.
|
||||||
double beta = 0.0;
|
double beta = 0.0;
|
||||||
|
|
||||||
ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
|
CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_DOUBLE);
|
||||||
ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
|
CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_DOUBLE);
|
||||||
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
|
CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
|
||||||
|
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
auto status = [&] {
|
auto status = [&] {
|
||||||
@ -3705,9 +3698,9 @@ bool CudnnSupport::DoPoolBackward(
|
|||||||
// Beta is the scaling factor for output.
|
// Beta is the scaling factor for output.
|
||||||
float beta = 0.0;
|
float beta = 0.0;
|
||||||
|
|
||||||
ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_FLOAT);
|
||||||
ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_FLOAT);
|
||||||
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
|
CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
|
||||||
|
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
auto status = [&] {
|
auto status = [&] {
|
||||||
@ -3734,9 +3727,9 @@ bool CudnnSupport::DoPoolBackward(
|
|||||||
// Beta is the scaling factor for output.
|
// Beta is the scaling factor for output.
|
||||||
float beta = 0.0;
|
float beta = 0.0;
|
||||||
|
|
||||||
ScopedTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
|
CudnnTensorDescriptor src_desc(input_dimensions, CUDNN_DATA_HALF);
|
||||||
ScopedTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
|
CudnnTensorDescriptor dest_desc(output_dimensions, CUDNN_DATA_HALF);
|
||||||
ScopedPoolingDescriptor pooling_desc(pooling_dimensions);
|
CudnnPoolingDescriptor pooling_desc(pooling_dimensions);
|
||||||
|
|
||||||
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
auto cudnn = cudnn_->GetHandle(parent_, stream);
|
||||||
auto status = [&] {
|
auto status = [&] {
|
||||||
@ -3771,8 +3764,8 @@ bool CudnnSupport::DoNormalizeWithDimensions(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
|
||||||
ScopedNormalizeDescriptor normalize(normalize_descriptor);
|
CudnnNormalizeDescriptor normalize(normalize_descriptor);
|
||||||
|
|
||||||
// Alpha is the scaling factor for input.
|
// Alpha is the scaling factor for input.
|
||||||
float alpha = 1.0f;
|
float alpha = 1.0f;
|
||||||
@ -3808,8 +3801,8 @@ bool CudnnSupport::DoNormalizeBackwardWithDimensions(
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ScopedTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor dims(dimensions, CUDNN_DATA_FLOAT);
|
||||||
ScopedNormalizeDescriptor normalize(normalize_descriptor);
|
CudnnNormalizeDescriptor normalize(normalize_descriptor);
|
||||||
|
|
||||||
float alpha = 1.0f;
|
float alpha = 1.0f;
|
||||||
float beta = 0.0f;
|
float beta = 0.0f;
|
||||||
@ -3932,9 +3925,9 @@ bool CudnnSupport::DeriveOutputBatchDescriptor(
|
|||||||
const dnn::FilterDescriptor& filter_descriptor,
|
const dnn::FilterDescriptor& filter_descriptor,
|
||||||
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
const dnn::ConvolutionDescriptor& convolution_descriptor,
|
||||||
dnn::BatchDescriptor* output_batch_descriptor) {
|
dnn::BatchDescriptor* output_batch_descriptor) {
|
||||||
ScopedTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
|
CudnnTensorDescriptor input_nd(batch_descriptor, CUDNN_DATA_FLOAT);
|
||||||
ScopedFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
|
CudnnFilterDescriptor filter(filter_descriptor, CUDNN_DATA_FLOAT);
|
||||||
ScopedConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
|
CudnnConvolutionDescriptor conv(convolution_descriptor, CUDNN_DATA_FLOAT);
|
||||||
|
|
||||||
int dn = batch_descriptor.ndims() + 2;
|
int dn = batch_descriptor.ndims() + 2;
|
||||||
std::vector<int> dims(dn); // in BDYX
|
std::vector<int> dims(dn); // in BDYX
|
||||||
|
@ -15,9 +15,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
|
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
|
||||||
|
|
||||||
#if defined(__APPLE__)
|
|
||||||
#include <mach-o/dyld.h>
|
|
||||||
#endif
|
|
||||||
#if defined(PLATFORM_WINDOWS)
|
#if defined(PLATFORM_WINDOWS)
|
||||||
#include <windows.h>
|
#include <windows.h>
|
||||||
#define PATH_MAX MAX_PATH
|
#define PATH_MAX MAX_PATH
|
||||||
@ -179,19 +176,11 @@ bool CUDAExecutor::FindOnDiskForComputeCapability(
|
|||||||
// would return /usr/bin.
|
// would return /usr/bin.
|
||||||
static string GetBinaryDir(bool strip_exe) {
|
static string GetBinaryDir(bool strip_exe) {
|
||||||
char exe_path[PATH_MAX] = {0};
|
char exe_path[PATH_MAX] = {0};
|
||||||
#if defined(__APPLE__)
|
|
||||||
uint32_t buffer_size = 0U;
|
|
||||||
_NSGetExecutablePath(nullptr, &buffer_size);
|
|
||||||
char unresolved_path[buffer_size];
|
|
||||||
_NSGetExecutablePath(unresolved_path, &buffer_size);
|
|
||||||
CHECK_ERR(realpath(unresolved_path, exe_path) ? 1 : -1);
|
|
||||||
#else
|
|
||||||
#if defined(PLATFORM_WINDOWS)
|
#if defined(PLATFORM_WINDOWS)
|
||||||
HMODULE hModule = GetModuleHandle(NULL);
|
HMODULE hModule = GetModuleHandle(NULL);
|
||||||
GetModuleFileName(hModule, exe_path, MAX_PATH);
|
GetModuleFileName(hModule, exe_path, MAX_PATH);
|
||||||
#else
|
#else
|
||||||
CHECK_ERR(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1));
|
CHECK_ERR(readlink("/proc/self/exe", exe_path, sizeof(exe_path) - 1));
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
// Make sure it's null-terminated:
|
// Make sure it's null-terminated:
|
||||||
exe_path[sizeof(exe_path) - 1] = 0;
|
exe_path[sizeof(exe_path) - 1] = 0;
|
||||||
@ -854,10 +843,7 @@ CudaContext* CUDAExecutor::cuda_context() { return context_; }
|
|||||||
// For anything more complicated/prod-focused than this, you'll likely want to
|
// For anything more complicated/prod-focused than this, you'll likely want to
|
||||||
// turn to gsys' topology modeling.
|
// turn to gsys' topology modeling.
|
||||||
static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) {
|
static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) {
|
||||||
#if defined(__APPLE__)
|
#if defined(PLATFORM_WINDOWS)
|
||||||
LOG(INFO) << "OS X does not support NUMA - returning NUMA node zero";
|
|
||||||
return 0;
|
|
||||||
#elif defined(PLATFORM_WINDOWS)
|
|
||||||
// Windows support for NUMA is not currently implemented. Return node 0.
|
// Windows support for NUMA is not currently implemented. Return node 0.
|
||||||
return 0;
|
return 0;
|
||||||
#elif defined(__aarch64__)
|
#elif defined(__aarch64__)
|
||||||
|
@ -5,12 +5,21 @@ licenses(["notice"]) # Apache 2.0
|
|||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
load("//tensorflow/tools/api/generator:api_gen.bzl", "TENSORFLOW_API_INIT_FILES")
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "doc_srcs",
|
||||||
|
srcs = ["doc_srcs.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
)
|
||||||
|
|
||||||
py_binary(
|
py_binary(
|
||||||
name = "create_python_api",
|
name = "create_python_api",
|
||||||
srcs = ["create_python_api.py"],
|
srcs = ["create_python_api.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":doc_srcs",
|
||||||
"//tensorflow/python:no_contrib",
|
"//tensorflow/python:no_contrib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -24,3 +33,18 @@ py_test(
|
|||||||
"//tensorflow/python:client_testlib",
|
"//tensorflow/python:client_testlib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "tensorflow_doc_srcs_test",
|
||||||
|
srcs = ["doc_srcs_test.py"],
|
||||||
|
args = [
|
||||||
|
"--package=tensorflow.python",
|
||||||
|
] + TENSORFLOW_API_INIT_FILES,
|
||||||
|
main = "doc_srcs_test.py",
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":doc_srcs",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:no_contrib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -26,6 +26,7 @@ import sys
|
|||||||
|
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_decorator
|
||||||
from tensorflow.python.util import tf_export
|
from tensorflow.python.util import tf_export
|
||||||
|
from tensorflow.tools.api.generator import doc_srcs
|
||||||
|
|
||||||
API_ATTRS = tf_export.API_ATTRS
|
API_ATTRS = tf_export.API_ATTRS
|
||||||
|
|
||||||
@ -36,10 +37,9 @@ _SYMBOLS_TO_SKIP_EXPLICITLY = {
|
|||||||
# would have side effects.
|
# would have side effects.
|
||||||
'tensorflow.python.platform.flags.FLAGS'
|
'tensorflow.python.platform.flags.FLAGS'
|
||||||
}
|
}
|
||||||
_GENERATED_FILE_HEADER = """\"\"\"Imports for Python API.
|
_GENERATED_FILE_HEADER = """# This file is MACHINE GENERATED! Do not edit.
|
||||||
|
# Generated by: tensorflow/tools/api/generator/create_python_api.py script.
|
||||||
This file is MACHINE GENERATED! Do not edit.
|
\"\"\"%s
|
||||||
Generated by: tensorflow/tools/api/generator/create_python_api.py script.
|
|
||||||
\"\"\"
|
\"\"\"
|
||||||
|
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
@ -252,6 +252,44 @@ def get_module(dir_path, relative_to_dir):
|
|||||||
return dir_path.replace('/', '.').strip('.')
|
return dir_path.replace('/', '.').strip('.')
|
||||||
|
|
||||||
|
|
||||||
|
def get_module_docstring(module_name, package):
|
||||||
|
"""Get docstring for the given module.
|
||||||
|
|
||||||
|
This method looks for docstring in the following order:
|
||||||
|
1. Checks if module has a docstring specified in doc_srcs.
|
||||||
|
2. Checks if module has a docstring source module specified
|
||||||
|
in doc_srcs. If it does, gets docstring from that module.
|
||||||
|
3. Checks if module with module_name exists under base package.
|
||||||
|
If it does, gets docstring from that module.
|
||||||
|
4. Returns a default docstring.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
module_name: module name relative to tensorflow
|
||||||
|
(excluding 'tensorflow.' prefix) to get a docstring for.
|
||||||
|
package: Base python package containing python with target tf_export
|
||||||
|
decorators.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
One-line docstring to describe the module.
|
||||||
|
"""
|
||||||
|
# Module under base package to get a docstring from.
|
||||||
|
docstring_module_name = module_name
|
||||||
|
|
||||||
|
if module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
|
||||||
|
docsrc = doc_srcs.TENSORFLOW_DOC_SOURCES[module_name]
|
||||||
|
if docsrc.docstring:
|
||||||
|
return docsrc.docstring
|
||||||
|
if docsrc.docstring_module_name:
|
||||||
|
docstring_module_name = docsrc.docstring_module_name
|
||||||
|
|
||||||
|
docstring_module_name = package + '.' + docstring_module_name
|
||||||
|
if (docstring_module_name in sys.modules and
|
||||||
|
sys.modules[docstring_module_name].__doc__):
|
||||||
|
return sys.modules[docstring_module_name].__doc__
|
||||||
|
|
||||||
|
return 'Public API for tf.%s namespace.' % module_name
|
||||||
|
|
||||||
|
|
||||||
def create_api_files(
|
def create_api_files(
|
||||||
output_files, package, root_init_template, output_dir, api_name):
|
output_files, package, root_init_template, output_dir, api_name):
|
||||||
"""Creates __init__.py files for the Python API.
|
"""Creates __init__.py files for the Python API.
|
||||||
@ -295,7 +333,9 @@ def create_api_files(
|
|||||||
continue
|
continue
|
||||||
contents = ''
|
contents = ''
|
||||||
if module or not root_init_template:
|
if module or not root_init_template:
|
||||||
contents = _GENERATED_FILE_HEADER + text + _GENERATED_FILE_FOOTER
|
contents = (
|
||||||
|
_GENERATED_FILE_HEADER %
|
||||||
|
get_module_docstring(module, package) + text + _GENERATED_FILE_FOOTER)
|
||||||
else:
|
else:
|
||||||
# Read base init file
|
# Read base init file
|
||||||
with open(root_init_template, 'r') as root_init_template_file:
|
with open(root_init_template, 'r') as root_init_template_file:
|
||||||
@ -308,7 +348,7 @@ def create_api_files(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
'Missing outputs for python_api_gen genrule:\n%s.'
|
'Missing outputs for python_api_gen genrule:\n%s.'
|
||||||
'Make sure all required outputs are in the '
|
'Make sure all required outputs are in the '
|
||||||
'tensorflow/tools/api/generator/BUILD file.' %
|
'tensorflow/tools/api/generator/api_gen.bzl file.' %
|
||||||
',\n'.join(sorted(missing_output_files)))
|
',\n'.join(sorted(missing_output_files)))
|
||||||
|
|
||||||
|
|
||||||
|
65
tensorflow/tools/api/generator/doc_srcs.py
Normal file
65
tensorflow/tools/api/generator/doc_srcs.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Specifies sources of doc strings for API modules."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
|
||||||
|
# Specifies docstring source for a module.
|
||||||
|
# Only one of docstring or docstring_module_name should be set.
|
||||||
|
# * If docstring is set, then we will use this docstring when
|
||||||
|
# for the module.
|
||||||
|
# * If docstring_module_name is set, then we will copy the docstring
|
||||||
|
# from docstring source module.
|
||||||
|
DocSource = collections.namedtuple(
|
||||||
|
'DocSource', ['docstring', 'docstring_module_name'])
|
||||||
|
# Each attribute of DocSource is optional.
|
||||||
|
DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
|
||||||
|
|
||||||
|
TENSORFLOW_DOC_SOURCES = {
|
||||||
|
'app': DocSource(docstring_module_name='platform.app'),
|
||||||
|
'compat': DocSource(docstring_module_name='util.compat'),
|
||||||
|
'distributions': DocSource(
|
||||||
|
docstring_module_name='ops.distributions.distributions'),
|
||||||
|
'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
|
||||||
|
'errors': DocSource(docstring_module_name='framework.errors'),
|
||||||
|
'gfile': DocSource(docstring_module_name='platform.gfile'),
|
||||||
|
'graph_util': DocSource(docstring_module_name='framework.graph_util'),
|
||||||
|
'image': DocSource(docstring_module_name='ops.image_ops'),
|
||||||
|
'keras.estimator': DocSource(docstring_module_name='estimator.keras'),
|
||||||
|
'linalg': DocSource(docstring_module_name='ops.linalg_ops'),
|
||||||
|
'logging': DocSource(docstring_module_name='ops.logging_ops'),
|
||||||
|
'losses': DocSource(docstring_module_name='ops.losses.losses'),
|
||||||
|
'manip': DocSource(docstring_module_name='ops.manip_ops'),
|
||||||
|
'math': DocSource(docstring_module_name='ops.math_ops'),
|
||||||
|
'metrics': DocSource(docstring_module_name='ops.metrics'),
|
||||||
|
'nn': DocSource(docstring_module_name='ops.nn_ops'),
|
||||||
|
'nn.rnn_cell': DocSource(docstring_module_name='ops.rnn_cell'),
|
||||||
|
'python_io': DocSource(docstring_module_name='lib.io.python_io'),
|
||||||
|
'resource_loader': DocSource(
|
||||||
|
docstring_module_name='platform.resource_loader'),
|
||||||
|
'sets': DocSource(docstring_module_name='ops.sets'),
|
||||||
|
'sparse': DocSource(docstring_module_name='ops.sparse_ops'),
|
||||||
|
'spectral': DocSource(docstring_module_name='ops.spectral_ops'),
|
||||||
|
'strings': DocSource(docstring_module_name='ops.string_ops'),
|
||||||
|
'sysconfig': DocSource(docstring_module_name='platform.sysconfig'),
|
||||||
|
'test': DocSource(docstring_module_name='platform.test'),
|
||||||
|
'train': DocSource(docstring_module_name='training.training'),
|
||||||
|
'train.queue_runner': DocSource(
|
||||||
|
docstring_module_name='training.queue_runner'),
|
||||||
|
}
|
80
tensorflow/tools/api/generator/doc_srcs_test.py
Normal file
80
tensorflow/tools/api/generator/doc_srcs_test.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# =============================================================================
|
||||||
|
"""Tests for tensorflow.tools.api.generator.doc_srcs."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.tools.api.generator import doc_srcs
|
||||||
|
|
||||||
|
|
||||||
|
FLAGS = None
|
||||||
|
|
||||||
|
|
||||||
|
class DocSrcsTest(test.TestCase):
  """Validates the doc-source table used by the TensorFlow API generator.

  Each test iterates over `doc_srcs.TENSORFLOW_DOC_SOURCES` (a mapping of
  API module name -> DocSource) and cross-checks it against the generator
  outputs (`FLAGS.outputs`) and the importable package (`FLAGS.package`)
  supplied on the command line by the __main__ block below.
  """

  def testModulesAreValidAPIModules(self):
    """Every doc-source key must correspond to a generated API module."""
    for module_name in doc_srcs.TENSORFLOW_DOC_SOURCES:
      # Convert module_name to corresponding __init__.py file path.
      file_path = module_name.replace('.', '/')
      if file_path:
        file_path += '/'
      file_path += '__init__.py'

      if file_path not in FLAGS.outputs:
        # Fail explicitly with a message; assertFalse on a truthy string was
        # the old idiom but self.fail states the intent directly.
        self.fail('%s is not a valid API module' % module_name)

  def testHaveDocstringOrDocstringModule(self):
    """A DocSource may set a docstring or a docstring_module_name, not both."""
    for module_name, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
      if docsrc.docstring and docsrc.docstring_module_name:
        self.fail(
            '%s contains a DocSource that has both a docstring and a '
            'docstring_module_name. '
            'Only one of "docstring" or "docstring_module_name" should be set.'
            % (module_name))

  def testDocstringModulesAreValidModules(self):
    """Every docstring_module_name must be importable under FLAGS.package."""
    for _, docsrc in doc_srcs.TENSORFLOW_DOC_SOURCES.items():
      if docsrc.docstring_module_name:
        doc_module_name = '.'.join([
            FLAGS.package, docsrc.docstring_module_name])
        # The __main__ block imports FLAGS.package first, so any valid
        # submodule is expected to already be registered in sys.modules.
        if doc_module_name not in sys.modules:
          # BUG FIX: the original called sys.assertFalse(...), which raises
          # AttributeError (module 'sys' has no such attribute) instead of
          # reporting a test failure.
          self.fail(
              'docsources_module %s is not a valid module under %s.' %
              (docsrc.docstring_module_name, FLAGS.package))
|
||||||
|
if __name__ == '__main__':
  # Parse the generator outputs (positional) and the base package (flag)
  # that the DocSrcsTest cases read via the module-level FLAGS.
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      'outputs', metavar='O', type=str, nargs='+',
      help='create_python_api output files.')
  arg_parser.add_argument(
      '--package', type=str,
      help='Base package that imports modules containing the target tf_export '
      'decorators.')
  FLAGS, remaining_argv = arg_parser.parse_known_args()

  # Importing the base package registers its submodules in sys.modules,
  # which testDocstringModulesAreValidModules relies on.
  importlib.import_module(FLAGS.package)

  # Hand only the unrecognized arguments to the unittest runner so it does
  # not choke on our custom flags.
  sys.argv = [sys.argv[0]] + remaining_argv
  test.main()
|
@ -20,6 +20,10 @@ tf_module {
|
|||||||
name: "adjust_hue"
|
name: "adjust_hue"
|
||||||
argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "adjust_jpeg_quality"
|
||||||
|
argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "adjust_saturation"
|
name: "adjust_saturation"
|
||||||
argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
@ -144,6 +148,10 @@ tf_module {
|
|||||||
name: "random_hue"
|
name: "random_hue"
|
||||||
argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "random_jpeg_quality"
|
||||||
|
argspec: "args=[\'image\', \'min_jpeg_quality\', \'max_jpeg_quality\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "random_saturation"
|
name: "random_saturation"
|
||||||
argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
@ -166,7 +174,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "resize_images"
|
name: "resize_images"
|
||||||
argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\'], varargs=None, keywords=None, defaults=[\'0\', \'False\'], "
|
argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "resize_nearest_neighbor"
|
name: "resize_nearest_neighbor"
|
||||||
|
@ -1,5 +1,21 @@
|
|||||||
path: "tensorflow.math"
|
path: "tensorflow.math"
|
||||||
tf_module {
|
tf_module {
|
||||||
|
member_method {
|
||||||
|
name: "bessel_i0"
|
||||||
|
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i0\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "bessel_i0e"
|
||||||
|
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "bessel_i1"
|
||||||
|
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'bessel_i1\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "bessel_i1e"
|
||||||
|
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "polyval"
|
name: "polyval"
|
||||||
argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||||
|
@ -2,7 +2,7 @@ path: "tensorflow.saved_model.loader"
|
|||||||
tf_module {
|
tf_module {
|
||||||
member_method {
|
member_method {
|
||||||
name: "load"
|
name: "load"
|
||||||
argspec: "args=[\'sess\', \'tags\', \'export_dir\'], varargs=None, keywords=saver_kwargs, defaults=None"
|
argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "maybe_saved_model_directory"
|
name: "maybe_saved_model_directory"
|
||||||
|
@ -778,11 +778,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
|
|||||||
actual = "@grpc//:grpc_python_plugin",
|
actual = "@grpc//:grpc_python_plugin",
|
||||||
)
|
)
|
||||||
|
|
||||||
# gRPC has three empty C++ functions which it wants the user to define
|
|
||||||
# at build time. https://github.com/grpc/grpc/issues/13590
|
|
||||||
native.bind(
|
native.bind(
|
||||||
name = "grpc_lib",
|
name = "grpc_lib",
|
||||||
actual = "@grpc//:grpc++_unsecure",
|
actual = "@grpc//:grpc++",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Needed by gRPC
|
# Needed by gRPC
|
||||||
|
Loading…
x
Reference in New Issue
Block a user