Merge branch 'r0.10' of github.com:tensorflow/tensorflow into r0.10

This commit is contained in:
Gunhan Gulsoy 2016-08-22 10:48:14 -07:00
commit 67734a1df6
37 changed files with 2741 additions and 855 deletions

View File

@ -441,7 +441,7 @@ static void TF_Run_Helper(
const std::vector<tensorflow::string>& output_tensor_names, const std::vector<tensorflow::string>& output_tensor_names,
TF_Tensor** c_outputs, TF_Tensor** c_outputs,
// Target nodes // Target nodes
const std::vector<tensorflow::string>& target_node_names, const std::vector<tensorflow::string>& target_oper_names,
TF_Buffer* run_metadata, TF_Status* status) { TF_Buffer* run_metadata, TF_Status* status) {
const int noutputs = output_tensor_names.size(); const int noutputs = output_tensor_names.size();
std::vector<Tensor> outputs(noutputs); std::vector<Tensor> outputs(noutputs);
@ -464,7 +464,7 @@ static void TF_Run_Helper(
RunMetadata run_metadata_proto; RunMetadata run_metadata_proto;
result = session->Run(run_options_proto, input_pairs, output_tensor_names, result = session->Run(run_options_proto, input_pairs, output_tensor_names,
target_node_names, &outputs, &run_metadata_proto); target_oper_names, &outputs, &run_metadata_proto);
// Serialize back to upstream client, who now owns the new buffer // Serialize back to upstream client, who now owns the new buffer
if (run_metadata != nullptr) { if (run_metadata != nullptr) {
@ -512,10 +512,9 @@ void TF_Run(TF_Session* s, const TF_Buffer* run_options,
// Input tensors // Input tensors
const char** c_input_names, TF_Tensor** c_inputs, int ninputs, const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
// Output tensors // Output tensors
const char** c_output_tensor_names, TF_Tensor** c_outputs, const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
int noutputs,
// Target nodes // Target nodes
const char** c_target_node_names, int ntargets, const char** c_target_oper_names, int ntargets,
TF_Buffer* run_metadata, TF_Status* status) { TF_Buffer* run_metadata, TF_Status* status) {
TF_Run_Setup(noutputs, c_outputs, status); TF_Run_Setup(noutputs, c_outputs, status);
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
@ -523,45 +522,44 @@ void TF_Run(TF_Session* s, const TF_Buffer* run_options,
for (int i = 0; i < ninputs; ++i) { for (int i = 0; i < ninputs; ++i) {
input_pairs[i].first = c_input_names[i]; input_pairs[i].first = c_input_names[i];
} }
std::vector<tensorflow::string> output_tensor_names(noutputs); std::vector<tensorflow::string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) { for (int i = 0; i < noutputs; ++i) {
output_tensor_names[i] = c_output_tensor_names[i]; output_names[i] = c_output_names[i];
} }
std::vector<tensorflow::string> target_node_names(ntargets); std::vector<tensorflow::string> target_oper_names(ntargets);
for (int i = 0; i < ntargets; ++i) { for (int i = 0; i < ntargets; ++i) {
target_node_names[i] = c_target_node_names[i]; target_oper_names[i] = c_target_oper_names[i];
} }
TF_Run_Helper(s->session, nullptr, run_options, input_pairs, TF_Run_Helper(s->session, nullptr, run_options, input_pairs, output_names,
output_tensor_names, c_outputs, target_node_names, run_metadata, c_outputs, target_oper_names, run_metadata, status);
status);
} }
void TF_PRunSetup(TF_Session* s, void TF_PRunSetup(TF_Session* s,
// Input names // Input names
const char** c_input_names, int ninputs, const char** c_input_names, int ninputs,
// Output names // Output names
const char** c_output_tensor_names, int noutputs, const char** c_output_names, int noutputs,
// Target nodes // Target nodes
const char** c_target_node_names, int ntargets, const char** c_target_oper_names, int ntargets,
const char** handle, TF_Status* status) { const char** handle, TF_Status* status) {
status->status = Status::OK(); status->status = Status::OK();
std::vector<tensorflow::string> input_names(ninputs); std::vector<tensorflow::string> input_names(ninputs);
std::vector<tensorflow::string> output_tensor_names(noutputs); std::vector<tensorflow::string> output_names(noutputs);
std::vector<tensorflow::string> target_node_names(ntargets); std::vector<tensorflow::string> target_oper_names(ntargets);
for (int i = 0; i < ninputs; ++i) { for (int i = 0; i < ninputs; ++i) {
input_names[i] = c_input_names[i]; input_names[i] = c_input_names[i];
} }
for (int i = 0; i < noutputs; ++i) { for (int i = 0; i < noutputs; ++i) {
output_tensor_names[i] = c_output_tensor_names[i]; output_names[i] = c_output_names[i];
} }
for (int i = 0; i < ntargets; ++i) { for (int i = 0; i < ntargets; ++i) {
target_node_names[i] = c_target_node_names[i]; target_oper_names[i] = c_target_oper_names[i];
} }
tensorflow::string new_handle; tensorflow::string new_handle;
Status result; Status result;
result = s->session->PRunSetup(input_names, output_tensor_names, result = s->session->PRunSetup(input_names, output_names, target_oper_names,
target_node_names, &new_handle); &new_handle);
if (result.ok()) { if (result.ok()) {
char* buf = new char[new_handle.size() + 1]; char* buf = new char[new_handle.size() + 1];
memcpy(buf, new_handle.c_str(), new_handle.size() + 1); memcpy(buf, new_handle.c_str(), new_handle.size() + 1);
@ -575,10 +573,9 @@ void TF_PRun(TF_Session* s, const char* handle,
// Input tensors // Input tensors
const char** c_input_names, TF_Tensor** c_inputs, int ninputs, const char** c_input_names, TF_Tensor** c_inputs, int ninputs,
// Output tensors // Output tensors
const char** c_output_tensor_names, TF_Tensor** c_outputs, const char** c_output_names, TF_Tensor** c_outputs, int noutputs,
int noutputs,
// Target nodes // Target nodes
const char** c_target_node_names, int ntargets, const char** c_target_oper_names, int ntargets,
TF_Status* status) { TF_Status* status) {
TF_Run_Setup(noutputs, c_outputs, status); TF_Run_Setup(noutputs, c_outputs, status);
std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs); std::vector<std::pair<tensorflow::string, Tensor>> input_pairs(ninputs);
@ -587,16 +584,16 @@ void TF_PRun(TF_Session* s, const char* handle,
input_pairs[i].first = c_input_names[i]; input_pairs[i].first = c_input_names[i];
} }
std::vector<tensorflow::string> output_tensor_names(noutputs); std::vector<tensorflow::string> output_names(noutputs);
for (int i = 0; i < noutputs; ++i) { for (int i = 0; i < noutputs; ++i) {
output_tensor_names[i] = c_output_tensor_names[i]; output_names[i] = c_output_names[i];
} }
std::vector<tensorflow::string> target_node_names(ntargets); std::vector<tensorflow::string> target_oper_names(ntargets);
for (int i = 0; i < ntargets; ++i) { for (int i = 0; i < ntargets; ++i) {
target_node_names[i] = c_target_node_names[i]; target_oper_names[i] = c_target_oper_names[i];
} }
TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_tensor_names, TF_Run_Helper(s->session, handle, nullptr, input_pairs, output_names,
c_outputs, target_node_names, nullptr, status); c_outputs, target_oper_names, nullptr, status);
} }
struct TF_Library { struct TF_Library {
@ -643,15 +640,16 @@ struct TF_Graph {
bool delete_requested; // set true by TF_DeleteGraph bool delete_requested; // set true by TF_DeleteGraph
}; };
struct TF_NodeDescription { struct TF_OperationDescription {
TF_NodeDescription(TF_Graph* g, const char* op_type, const char* node_name) TF_OperationDescription(TF_Graph* g, const char* op_type,
const char* node_name)
: node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {} : node_builder(node_name, op_type, g->graph.op_registry()), graph(g) {}
NodeBuilder node_builder; NodeBuilder node_builder;
TF_Graph* graph; TF_Graph* graph;
}; };
struct TF_Node { struct TF_Operation {
Node node; Node node;
}; };
@ -670,55 +668,56 @@ struct TF_SessionWithGraph {
namespace { namespace {
TF_Node* ToNode(Node* node) { TF_Operation* ToOperation(Node* node) {
return static_cast<TF_Node*>(static_cast<void*>(node)); return static_cast<TF_Operation*>(static_cast<void*>(node));
} }
tensorflow::string PortName(const TF_Port& port) { tensorflow::string PortName(const TF_Port& port) {
return tensorflow::strings::StrCat(port.node->node.name(), ":", port.index); return tensorflow::strings::StrCat(port.oper->node.name(), ":", port.index);
} }
} // namespace } // namespace
// TF_NodeDescription functions ----------------------------------------------- // TF_OperationDescription functions
// -----------------------------------------------
extern "C" { extern "C" {
TF_NodeDescription* TF_NewNode(TF_Graph* graph, const char* op_type, TF_OperationDescription* TF_NewOperation(TF_Graph* graph, const char* op_type,
const char* node_name) { const char* oper_name) {
mutex_lock l(graph->mu); mutex_lock l(graph->mu);
return new TF_NodeDescription(graph, op_type, node_name); return new TF_OperationDescription(graph, op_type, oper_name);
} }
void TF_SetDevice(TF_NodeDescription* desc, const char* device) { void TF_SetDevice(TF_OperationDescription* desc, const char* device) {
desc->node_builder.Device(device); desc->node_builder.Device(device);
} }
void TF_AddInput(TF_NodeDescription* desc, TF_Port input) { void TF_AddInput(TF_OperationDescription* desc, TF_Port input) {
desc->node_builder.Input(&input.node->node, input.index); desc->node_builder.Input(&input.oper->node, input.index);
} }
void TF_AddInputList(TF_NodeDescription* desc, const TF_Port* inputs, void TF_AddInputList(TF_OperationDescription* desc, const TF_Port* inputs,
int num_inputs) { int num_inputs) {
std::vector<NodeBuilder::NodeOut> input_list; std::vector<NodeBuilder::NodeOut> input_list;
input_list.reserve(num_inputs); input_list.reserve(num_inputs);
for (int i = 0; i < num_inputs; ++i) { for (int i = 0; i < num_inputs; ++i) {
input_list.emplace_back(&inputs[i].node->node, inputs[i].index); input_list.emplace_back(&inputs[i].oper->node, inputs[i].index);
} }
desc->node_builder.Input(input_list); desc->node_builder.Input(input_list);
} }
void TF_AddControlInput(TF_NodeDescription* desc, TF_Node* input) { void TF_AddControlInput(TF_OperationDescription* desc, TF_Operation* input) {
desc->node_builder.ControlInput(&input->node); desc->node_builder.ControlInput(&input->node);
} }
void TF_SetAttrString(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrString(TF_OperationDescription* desc, const char* attr_name,
const void* value, int length) { const void* value, int length) {
tensorflow::StringPiece s(static_cast<const char*>(value), length); tensorflow::StringPiece s(static_cast<const char*>(value), length);
desc->node_builder.Attr(attr_name, s); desc->node_builder.Attr(attr_name, s);
} }
void TF_SetAttrStringList(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrStringList(TF_OperationDescription* desc, const char* attr_name,
const void* const* values, const int* lengths, const void* const* values, const int* lengths,
int num_values) { int num_values) {
std::vector<tensorflow::StringPiece> v; std::vector<tensorflow::StringPiece> v;
@ -729,14 +728,14 @@ void TF_SetAttrStringList(TF_NodeDescription* desc, const char* attr_name,
desc->node_builder.Attr(attr_name, v); desc->node_builder.Attr(attr_name, v);
} }
void TF_SetAttrInt(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
int64_t value) { int64_t value) {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size"); "64-bit int types should match in size");
desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value)); desc->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
} }
void TF_SetAttrIntList(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrIntList(TF_OperationDescription* desc, const char* attr_name,
const int64_t* values, int num_values) { const int64_t* values, int num_values) {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size"); "64-bit int types should match in size");
@ -746,23 +745,23 @@ void TF_SetAttrIntList(TF_NodeDescription* desc, const char* attr_name,
reinterpret_cast<const tensorflow::int64*>(values), num_values)); reinterpret_cast<const tensorflow::int64*>(values), num_values));
} }
void TF_SetAttrFloat(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrFloat(TF_OperationDescription* desc, const char* attr_name,
float value) { float value) {
desc->node_builder.Attr(attr_name, value); desc->node_builder.Attr(attr_name, value);
} }
void TF_SetAttrFloatList(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrFloatList(TF_OperationDescription* desc, const char* attr_name,
const float* values, int num_values) { const float* values, int num_values) {
desc->node_builder.Attr(attr_name, desc->node_builder.Attr(attr_name,
ArraySlice<const float>(values, num_values)); ArraySlice<const float>(values, num_values));
} }
void TF_SetAttrBool(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
unsigned char value) { unsigned char value) {
desc->node_builder.Attr(attr_name, static_cast<bool>(value)); desc->node_builder.Attr(attr_name, static_cast<bool>(value));
} }
void TF_SetAttrBoolList(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrBoolList(TF_OperationDescription* desc, const char* attr_name,
const unsigned char* values, int num_values) { const unsigned char* values, int num_values) {
bool* b = new bool[num_values]; bool* b = new bool[num_values];
for (int i = 0; i < num_values; ++i) { for (int i = 0; i < num_values; ++i) {
@ -771,19 +770,19 @@ void TF_SetAttrBoolList(TF_NodeDescription* desc, const char* attr_name,
desc->node_builder.Attr(attr_name, ArraySlice<const bool>(b, num_values)); desc->node_builder.Attr(attr_name, ArraySlice<const bool>(b, num_values));
} }
void TF_SetAttrType(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
TF_DataType value) { TF_DataType value) {
desc->node_builder.Attr(attr_name, static_cast<DataType>(value)); desc->node_builder.Attr(attr_name, static_cast<DataType>(value));
} }
void TF_SetAttrTypeList(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name,
const TF_DataType* values, int num_values) { const TF_DataType* values, int num_values) {
desc->node_builder.Attr( desc->node_builder.Attr(
attr_name, ArraySlice<const DataType>( attr_name, ArraySlice<const DataType>(
reinterpret_cast<const DataType*>(values), num_values)); reinterpret_cast<const DataType*>(values), num_values));
} }
void TF_SetAttrShape(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name,
const int64_t* dims, int num_dims) { const int64_t* dims, int num_dims) {
PartialTensorShape shape; PartialTensorShape shape;
if (num_dims >= 0) { if (num_dims >= 0) {
@ -795,7 +794,7 @@ void TF_SetAttrShape(TF_NodeDescription* desc, const char* attr_name,
desc->node_builder.Attr(attr_name, shape); desc->node_builder.Attr(attr_name, shape);
} }
void TF_SetAttrShapeList(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name,
const int64_t* const* dims, const int* num_dims, const int64_t* const* dims, const int* num_dims,
int num_shapes) { int num_shapes) {
std::vector<PartialTensorShape> shapes; std::vector<PartialTensorShape> shapes;
@ -813,8 +812,9 @@ void TF_SetAttrShapeList(TF_NodeDescription* desc, const char* attr_name,
desc->node_builder.Attr(attr_name, shapes); desc->node_builder.Attr(attr_name, shapes);
} }
void TF_SetAttrTensorShapeProto(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
void* proto, int proto_len, TF_Status* status) { const char* attr_name, void* proto,
int proto_len, TF_Status* status) {
TensorShapeProto shape; TensorShapeProto shape;
if (shape.ParseFromArray(proto, proto_len)) { if (shape.ParseFromArray(proto, proto_len)) {
desc->node_builder.Attr(attr_name, shape); desc->node_builder.Attr(attr_name, shape);
@ -825,7 +825,7 @@ void TF_SetAttrTensorShapeProto(TF_NodeDescription* desc, const char* attr_name,
} }
} }
void TF_SetAttrTensorShapeProtoList(TF_NodeDescription* desc, void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
const char* attr_name, const char* attr_name,
const void* const* protos, const void* const* protos,
const int* proto_lens, int num_shapes, const int* proto_lens, int num_shapes,
@ -843,7 +843,7 @@ void TF_SetAttrTensorShapeProtoList(TF_NodeDescription* desc,
status->status = Status::OK(); status->status = Status::OK();
} }
void TF_SetAttrTensor(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrTensor(TF_OperationDescription* desc, const char* attr_name,
TF_Tensor* value, TF_Status* status) { TF_Tensor* value, TF_Status* status) {
status->status = Status::OK(); status->status = Status::OK();
Tensor t; Tensor t;
@ -862,7 +862,7 @@ void TF_SetAttrTensor(TF_NodeDescription* desc, const char* attr_name,
if (ok) desc->node_builder.Attr(attr_name, t); if (ok) desc->node_builder.Attr(attr_name, t);
} }
void TF_SetAttrTensorList(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrTensorList(TF_OperationDescription* desc, const char* attr_name,
TF_Tensor* const* values, int num_values, TF_Tensor* const* values, int num_values,
TF_Status* status) { TF_Status* status) {
status->status = Status::OK(); status->status = Status::OK();
@ -890,9 +890,9 @@ void TF_SetAttrTensorList(TF_NodeDescription* desc, const char* attr_name,
if (ok) desc->node_builder.Attr(attr_name, t); if (ok) desc->node_builder.Attr(attr_name, t);
} }
void TF_SetAttrToAttrValueProto(TF_NodeDescription* desc, const char* attr_name, void TF_SetAttrToAttrValueProto(TF_OperationDescription* desc,
const void* proto, size_t proto_len, const char* attr_name, const void* proto,
TF_Status* status) { size_t proto_len, TF_Status* status) {
tensorflow::AttrValue attr_value; tensorflow::AttrValue attr_value;
if (attr_value.ParseFromArray(proto, proto_len)) { if (attr_value.ParseFromArray(proto, proto_len)) {
desc->node_builder.Attr(attr_name, attr_value); desc->node_builder.Attr(attr_name, attr_value);
@ -903,7 +903,8 @@ void TF_SetAttrToAttrValueProto(TF_NodeDescription* desc, const char* attr_name,
} }
} }
TF_Node* TF_FinishNode(TF_NodeDescription* desc, TF_Status* status) { TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
TF_Status* status) {
Node* ret = nullptr; Node* ret = nullptr;
mutex_lock l(desc->graph->mu); mutex_lock l(desc->graph->mu);
@ -919,32 +920,37 @@ TF_Node* TF_FinishNode(TF_NodeDescription* desc, TF_Status* status) {
delete desc; delete desc;
return ToNode(ret); return ToOperation(ret);
} }
// TF_Node functions ---------------------------------------------------------- // TF_Operation functions
// ----------------------------------------------------------
const char* TF_NodeName(TF_Node* node) { return node->node.name().c_str(); } const char* TF_OperationName(TF_Operation* oper) {
return oper->node.name().c_str();
const char* TF_NodeOpType(TF_Node* node) {
return node->node.type_string().c_str();
} }
const char* TF_NodeDevice(TF_Node* node) { const char* TF_OperationOpType(TF_Operation* oper) {
return node->node.def().device().c_str(); return oper->node.type_string().c_str();
} }
int TF_NodeNumOutputs(TF_Node* node) { return node->node.num_outputs(); } const char* TF_OperationDevice(TF_Operation* oper) {
return oper->node.def().device().c_str();
}
TF_DataType TF_NodeOutputType(TF_Port node_out) { int TF_OperationNumOutputs(TF_Operation* oper) {
return oper->node.num_outputs();
}
TF_DataType TF_OperationOutputType(TF_Port oper_out) {
return static_cast<TF_DataType>( return static_cast<TF_DataType>(
node_out.node->node.output_type(node_out.index)); oper_out.oper->node.output_type(oper_out.index));
} }
int TF_NodeOutputListLength(TF_Node* node, const char* arg_name, int TF_OperationOutputListLength(TF_Operation* oper, const char* arg_name,
TF_Status* status) { TF_Status* status) {
NameRangeMap name_ranges; NameRangeMap name_ranges;
status->status = NameRangesForNode(node->node.def(), node->node.op_def(), status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
nullptr, &name_ranges); nullptr, &name_ranges);
if (!status->status.ok()) return -1; if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name); auto iter = name_ranges.find(arg_name);
@ -956,16 +962,18 @@ int TF_NodeOutputListLength(TF_Node* node, const char* arg_name,
return iter->second.second - iter->second.first; return iter->second.second - iter->second.first;
} }
int TF_NodeNumInputs(TF_Node* node) { return node->node.num_inputs(); } int TF_OperationNumInputs(TF_Operation* oper) {
return oper->node.num_inputs();
TF_DataType TF_NodeInputType(TF_Port node_in) {
return static_cast<TF_DataType>(node_in.node->node.input_type(node_in.index));
} }
int TF_NodeInputListLength(TF_Node* node, const char* arg_name, TF_DataType TF_OperationInputType(TF_Port oper_in) {
TF_Status* status) { return static_cast<TF_DataType>(oper_in.oper->node.input_type(oper_in.index));
}
int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
TF_Status* status) {
NameRangeMap name_ranges; NameRangeMap name_ranges;
status->status = NameRangesForNode(node->node.def(), node->node.op_def(), status->status = NameRangesForNode(oper->node.def(), oper->node.op_def(),
&name_ranges, nullptr); &name_ranges, nullptr);
if (!status->status.ok()) return -1; if (!status->status.ok()) return -1;
auto iter = name_ranges.find(arg_name); auto iter = name_ranges.find(arg_name);
@ -977,32 +985,32 @@ int TF_NodeInputListLength(TF_Node* node, const char* arg_name,
return iter->second.second - iter->second.first; return iter->second.second - iter->second.first;
} }
TF_Port TF_NodeInput(TF_Port node_in) { TF_Port TF_OperationInput(TF_Port oper_in) {
for (const auto* edge : node_in.node->node.in_edges()) { for (const auto* edge : oper_in.oper->node.in_edges()) {
if (edge->dst_input() == node_in.index) { if (edge->dst_input() == oper_in.index) {
return {ToNode(edge->src()), edge->src_output()}; return {ToOperation(edge->src()), edge->src_output()};
} }
} }
return {nullptr, -1}; return {nullptr, -1};
} }
int TF_NodeOutputNumConsumers(TF_Port node_out) { int TF_OperationOutputNumConsumers(TF_Port oper_out) {
int count = 0; int count = 0;
for (const auto* edge : node_out.node->node.out_edges()) { for (const auto* edge : oper_out.oper->node.out_edges()) {
if (edge->src_output() == node_out.index) { if (edge->src_output() == oper_out.index) {
++count; ++count;
} }
} }
return count; return count;
} }
int TF_NodeOutputConsumers(TF_Port node_out, TF_Port* consumers, int TF_OperationOutputConsumers(TF_Port oper_out, TF_Port* consumers,
int max_consumers) { int max_consumers) {
int count = 0; int count = 0;
for (const auto* edge : node_out.node->node.out_edges()) { for (const auto* edge : oper_out.oper->node.out_edges()) {
if (edge->src_output() == node_out.index) { if (edge->src_output() == oper_out.index) {
if (count < max_consumers) { if (count < max_consumers) {
consumers[count] = {ToNode(edge->dst()), edge->dst_input()}; consumers[count] = {ToOperation(edge->dst()), edge->dst_input()};
} }
++count; ++count;
} }
@ -1010,9 +1018,9 @@ int TF_NodeOutputConsumers(TF_Port node_out, TF_Port* consumers,
return count; return count;
} }
int TF_NodeNumControlInputs(TF_Node* node) { int TF_OperationNumControlInputs(TF_Operation* oper) {
int count = 0; int count = 0;
for (const auto* edge : node->node.in_edges()) { for (const auto* edge : oper->node.in_edges()) {
if (edge->IsControlEdge()) { if (edge->IsControlEdge()) {
++count; ++count;
} }
@ -1020,13 +1028,14 @@ int TF_NodeNumControlInputs(TF_Node* node) {
return count; return count;
} }
int TF_NodeGetControlInputs(TF_Node* node, TF_Node** control_inputs, int TF_OperationGetControlInputs(TF_Operation* oper,
int max_control_inputs) { TF_Operation** control_inputs,
int max_control_inputs) {
int count = 0; int count = 0;
for (const auto* edge : node->node.in_edges()) { for (const auto* edge : oper->node.in_edges()) {
if (edge->IsControlEdge()) { if (edge->IsControlEdge()) {
if (count < max_control_inputs) { if (count < max_control_inputs) {
control_inputs[count] = ToNode(edge->src()); control_inputs[count] = ToOperation(edge->src());
} }
++count; ++count;
} }
@ -1034,9 +1043,9 @@ int TF_NodeGetControlInputs(TF_Node* node, TF_Node** control_inputs,
return count; return count;
} }
int TF_NodeNumControlOutputs(TF_Node* node) { int TF_OperationNumControlOutputs(TF_Operation* oper) {
int count = 0; int count = 0;
for (const auto* edge : node->node.out_edges()) { for (const auto* edge : oper->node.out_edges()) {
if (edge->IsControlEdge()) { if (edge->IsControlEdge()) {
++count; ++count;
} }
@ -1044,13 +1053,14 @@ int TF_NodeNumControlOutputs(TF_Node* node) {
return count; return count;
} }
int TF_NodeGetControlOutputs(TF_Node* node, TF_Node** control_outputs, int TF_OperationGetControlOutputs(TF_Operation* oper,
int max_control_outputs) { TF_Operation** control_outputs,
int max_control_outputs) {
int count = 0; int count = 0;
for (const auto* edge : node->node.out_edges()) { for (const auto* edge : oper->node.out_edges()) {
if (edge->IsControlEdge()) { if (edge->IsControlEdge()) {
if (count < max_control_outputs) { if (count < max_control_outputs) {
control_outputs[count] = ToNode(edge->dst()); control_outputs[count] = ToOperation(edge->dst());
} }
++count; ++count;
} }
@ -1058,19 +1068,20 @@ int TF_NodeGetControlOutputs(TF_Node* node, TF_Node** control_outputs,
return count; return count;
} }
void TF_NodeGetAttrValueProto(TF_Node* node, const char* attr_name, void TF_OperationGetAttrValueProto(TF_Operation* oper, const char* attr_name,
TF_Buffer* output_attr_value, TF_Status* status) { TF_Buffer* output_attr_value,
TF_Status* status) {
if (output_attr_value->data != nullptr) { if (output_attr_value->data != nullptr) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"Passing non-empty output_attr_value is invalid."); "Passing non-empty output_attr_value is invalid.");
return; return;
} }
const auto& attr_map = node->node.def().attr(); const auto& attr_map = oper->node.def().attr();
auto iter = attr_map.find(attr_name); auto iter = attr_map.find(attr_name);
if (iter == attr_map.end()) { if (iter == attr_map.end()) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"Node has no attr named '", attr_name, "'."); "Operation has no attr named '", attr_name, "'.");
return; return;
} }
@ -1086,15 +1097,15 @@ void TF_NodeGetAttrValueProto(TF_Node* node, const char* attr_name,
status->status = Status::OK(); status->status = Status::OK();
} }
void TF_NodeToNodeDef(TF_Node* node, TF_Buffer* output_node_def, void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
TF_Status* status) { TF_Status* status) {
if (output_node_def->data != nullptr) { if (output_node_def->data != nullptr) {
status->status = tensorflow::errors::InvalidArgument( status->status = tensorflow::errors::InvalidArgument(
"Passing non-empty output_node_def is invalid."); "Passing non-empty output_node_def is invalid.");
return; return;
} }
const NodeDef& def = node->node.def(); const NodeDef& def = oper->node.def();
const auto proto_size = def.ByteSize(); const auto proto_size = def.ByteSize();
void* str_buf = malloc(proto_size); void* str_buf = malloc(proto_size);
def.SerializeToArray(str_buf, proto_size); def.SerializeToArray(str_buf, proto_size);
@ -1118,17 +1129,17 @@ void TF_DeleteGraph(TF_Graph* g) {
if (del) delete g; if (del) delete g;
} }
TF_Node* TF_GraphNodeByName(TF_Graph* graph, const char* node_name) { TF_Operation* TF_GraphOperationByName(TF_Graph* graph, const char* oper_name) {
mutex_lock l(graph->mu); mutex_lock l(graph->mu);
auto iter = graph->name_map.find(node_name); auto iter = graph->name_map.find(oper_name);
if (iter == graph->name_map.end()) { if (iter == graph->name_map.end()) {
return nullptr; return nullptr;
} else { } else {
return ToNode(iter->second); return ToOperation(iter->second);
} }
} }
TF_Node* TF_GraphNextNode(TF_Graph* graph, size_t* pos) { TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos) {
if (*pos == 0) { if (*pos == 0) {
// Advance past the first sentinal nodes in every graph (the source & sink). // Advance past the first sentinal nodes in every graph (the source & sink).
*pos += 2; *pos += 2;
@ -1143,7 +1154,7 @@ TF_Node* TF_GraphNextNode(TF_Graph* graph, size_t* pos) {
// FindNodeId() returns nullptr for nodes that have been deleted. // FindNodeId() returns nullptr for nodes that have been deleted.
// We aren't currently allowing nodes to be deleted, but it is safer // We aren't currently allowing nodes to be deleted, but it is safer
// to still check. // to still check.
if (node != nullptr) return reinterpret_cast<TF_Node*>(node); if (node != nullptr) return ToOperation(node);
*pos += 1; *pos += 1;
} }
@ -1257,7 +1268,7 @@ void TF_SessionRun(TF_SessionWithGraph* session, const TF_Buffer* run_options,
const TF_Port* inputs, TF_Tensor* const* input_values, const TF_Port* inputs, TF_Tensor* const* input_values,
int ninputs, const TF_Port* outputs, int ninputs, const TF_Port* outputs,
TF_Tensor** output_values, int noutputs, TF_Tensor** output_values, int noutputs,
const TF_Node* const* target_nodes, int ntargets, const TF_Operation* const* target_opers, int ntargets,
TF_Buffer* run_metadata, TF_Status* status) { TF_Buffer* run_metadata, TF_Status* status) {
// TODO(josh11b,mrry): Change Session to be able to use a Graph* // TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and // directly, instead of requiring us to serialize to a GraphDef and
@ -1284,10 +1295,10 @@ void TF_SessionRun(TF_SessionWithGraph* session, const TF_Buffer* run_options,
output_names[i] = PortName(outputs[i]); output_names[i] = PortName(outputs[i]);
} }
// Convert from TF_Node* to string names. // Convert from TF_Operation* to string names.
std::vector<tensorflow::string> target_names(ntargets); std::vector<tensorflow::string> target_names(ntargets);
for (int i = 0; i < ntargets; ++i) { for (int i = 0; i < ntargets; ++i) {
target_names[i] = target_nodes[i]->node.name(); target_names[i] = target_opers[i]->node.name();
} }
// Actually run. // Actually run.
@ -1298,7 +1309,7 @@ void TF_SessionRun(TF_SessionWithGraph* session, const TF_Buffer* run_options,
void TF_SessionPRunSetup(TF_SessionWithGraph* session, const TF_Port* inputs, void TF_SessionPRunSetup(TF_SessionWithGraph* session, const TF_Port* inputs,
int ninputs, const TF_Port* outputs, int noutputs, int ninputs, const TF_Port* outputs, int noutputs,
const TF_Node* const* target_nodes, int ntargets, const TF_Operation* const* target_opers, int ntargets,
const char** handle, TF_Status* status) { const char** handle, TF_Status* status) {
if (!ExtendSessionGraphHelper(session, status)) { if (!ExtendSessionGraphHelper(session, status)) {
return; return;
@ -1316,7 +1327,7 @@ void TF_SessionPRunSetup(TF_SessionWithGraph* session, const TF_Port* inputs,
std::vector<tensorflow::string> target_names(ntargets); std::vector<tensorflow::string> target_names(ntargets);
for (int i = 0; i < ntargets; ++i) { for (int i = 0; i < ntargets; ++i) {
target_names[i] = target_nodes[i]->node.name(); target_names[i] = target_opers[i]->node.name();
} }
tensorflow::string new_handle; tensorflow::string new_handle;
@ -1333,7 +1344,7 @@ void TF_SessionPRun(TF_SessionWithGraph* session, const char* handle,
const TF_Port* inputs, TF_Tensor* const* input_values, const TF_Port* inputs, TF_Tensor* const* input_values,
int ninputs, const TF_Port* outputs, int ninputs, const TF_Port* outputs,
TF_Tensor** output_values, int noutputs, TF_Tensor** output_values, int noutputs,
const TF_Node* const* target_nodes, int ntargets, const TF_Operation* const* target_opers, int ntargets,
TF_Status* status) { TF_Status* status) {
// TODO(josh11b,mrry): Change Session to be able to use a Graph* // TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and // directly, instead of requiring us to serialize to a GraphDef and
@ -1360,10 +1371,10 @@ void TF_SessionPRun(TF_SessionWithGraph* session, const char* handle,
output_names[i] = PortName(outputs[i]); output_names[i] = PortName(outputs[i]);
} }
// Convert from TF_Node* to string names. // Convert from TF_Operation* to string names.
std::vector<tensorflow::string> target_names(ntargets); std::vector<tensorflow::string> target_names(ntargets);
for (int i = 0; i < ntargets; ++i) { for (int i = 0; i < ntargets; ++i) {
target_names[i] = target_nodes[i]->node.name(); target_names[i] = target_opers[i]->node.name();
} }
TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names, TF_Run_Helper(session->session, handle, nullptr, input_pairs, output_names,

View File

@ -247,29 +247,31 @@ extern TF_Graph* TF_NewGraph();
// TFSessionWithGraph's are referencing it. // TFSessionWithGraph's are referencing it.
extern void TF_DeleteGraph(TF_Graph*); extern void TF_DeleteGraph(TF_Graph*);
// Node being built. The underlying graph must outlive this. // Operation being built. The underlying graph must outlive this.
typedef struct TF_NodeDescription TF_NodeDescription; typedef struct TF_OperationDescription TF_OperationDescription;
// Node that has been added to the graph. Valid until the graph is // Operation that has been added to the graph. Valid until the graph is
// deleted -- in particular adding a new node to the graph does not // deleted -- in particular adding a new operation to the graph does not
// invalidate old TF_Node* pointers. // invalidate old TF_Operation* pointers.
typedef struct TF_Node TF_Node; typedef struct TF_Operation TF_Operation;
// Represents a specific input or output of a node, e.g. to specify the // Represents a specific input or output of an operation, e.g. to
// specific output to pass as an input to an op. // specify the specific output to pass as an input to a new op.
typedef struct TF_Port { typedef struct TF_Port {
TF_Node* node; TF_Operation* oper;
int index; // Specifies the index of the input or output within node. int index; // Specifies the index of the input or output within oper.
} TF_Port; } TF_Port;
// Node will only be added to *graph when TF_FinishNode() is called // Operation will only be added to *graph when TF_FinishOperation() is
// (assuming TF_FinishNode() does not return an error). *graph must // called (assuming TF_FinishOperation() does not return an error).
// not be deleted until after TF_FinishNode() is called. // *graph must not be deleted until after TF_FinishOperation() is
extern TF_NodeDescription* TF_NewNode(TF_Graph* graph, const char* op_type, // called.
const char* node_name); extern TF_OperationDescription* TF_NewOperation(TF_Graph* graph,
const char* op_type,
const char* oper_name);
// Specify the device for `desc`. Defaults to empty, meaning unconstrained. // Specify the device for `desc`. Defaults to empty, meaning unconstrained.
extern void TF_SetDevice(TF_NodeDescription* desc, const char* device); extern void TF_SetDevice(TF_OperationDescription* desc, const char* device);
// The calls to TF_AddInput and TF_AddInputList must match (in number, // The calls to TF_AddInput and TF_AddInputList must match (in number,
// order, and type) the op declaration. For example, the "Concat" op // order, and type) the op declaration. For example, the "Concat" op
@ -285,74 +287,82 @@ extern void TF_SetDevice(TF_NodeDescription* desc, const char* device);
// single tensor), and TF_AddInputList() for the second input (since // single tensor), and TF_AddInputList() for the second input (since
// it takes a list, even if you were to pass a list with a single // it takes a list, even if you were to pass a list with a single
// tensor), as in: // tensor), as in:
// TF_NodeDescription* desc = TF_NewNode(graph, "Concat", "c"); // TF_OperationDescription* desc = TF_NewOperation(graph, "Concat", "c");
// TF_Port concat_dim_input = {...}; // TF_Port concat_dim_input = {...};
// TF_AddInput(desc, concat_dim_input); // TF_AddInput(desc, concat_dim_input);
// TF_Port values_inputs[5] = {{...}, ..., {...}}; // TF_Port values_inputs[5] = {{...}, ..., {...}};
// TF_AddInputList(desc, 5, values_inputs); // TF_AddInputList(desc, 5, values_inputs);
// For inputs that take a single tensor. // For inputs that take a single tensor.
extern void TF_AddInput(TF_NodeDescription* desc, TF_Port input); extern void TF_AddInput(TF_OperationDescription* desc, TF_Port input);
// For inputs that take a list of tensors. // For inputs that take a list of tensors.
// inputs must point to TF_Port[num_inputs]. // inputs must point to TF_Port[num_inputs].
extern void TF_AddInputList(TF_NodeDescription* desc, const TF_Port* inputs, extern void TF_AddInputList(TF_OperationDescription* desc,
int num_inputs); const TF_Port* inputs, int num_inputs);
// Call once per control input to `desc`. // Call once per control input to `desc`.
extern void TF_AddControlInput(TF_NodeDescription* desc, TF_Node* input); extern void TF_AddControlInput(TF_OperationDescription* desc,
TF_Operation* input);
// Call some TF_SetAttr*() function for every attr that is not // Call some TF_SetAttr*() function for every attr that is not
// inferred from an input and doesn't have a default value you wish to // inferred from an input and doesn't have a default value you wish to
// keep. // keep.
// `value` must point to a string of length `length` bytes. // `value` must point to a string of length `length` bytes.
extern void TF_SetAttrString(TF_NodeDescription* desc, const char* attr_name, extern void TF_SetAttrString(TF_OperationDescription* desc,
const void* value, int length); const char* attr_name, const void* value,
int length);
// `values` and `lengths` both must have lengths `num_values`. // `values` and `lengths` both must have lengths `num_values`.
// `values[i]` must point to a string of length `lengths[i]` bytes. // `values[i]` must point to a string of length `lengths[i]` bytes.
extern void TF_SetAttrStringList(TF_NodeDescription* desc, extern void TF_SetAttrStringList(TF_OperationDescription* desc,
const char* attr_name, const char* attr_name,
const void* const* values, const int* lengths, const void* const* values, const int* lengths,
int num_values); int num_values);
extern void TF_SetAttrInt(TF_NodeDescription* desc, const char* attr_name, extern void TF_SetAttrInt(TF_OperationDescription* desc, const char* attr_name,
int64_t value); int64_t value);
extern void TF_SetAttrIntList(TF_NodeDescription* desc, const char* attr_name, extern void TF_SetAttrIntList(TF_OperationDescription* desc,
const int64_t* values, int num_values); const char* attr_name, const int64_t* values,
extern void TF_SetAttrFloat(TF_NodeDescription* desc, const char* attr_name, int num_values);
float value); extern void TF_SetAttrFloat(TF_OperationDescription* desc,
extern void TF_SetAttrFloatList(TF_NodeDescription* desc, const char* attr_name, const char* attr_name, float value);
const float* values, int num_values); extern void TF_SetAttrFloatList(TF_OperationDescription* desc,
extern void TF_SetAttrBool(TF_NodeDescription* desc, const char* attr_name, const char* attr_name, const float* values,
int num_values);
extern void TF_SetAttrBool(TF_OperationDescription* desc, const char* attr_name,
unsigned char value); unsigned char value);
extern void TF_SetAttrBoolList(TF_NodeDescription* desc, const char* attr_name, extern void TF_SetAttrBoolList(TF_OperationDescription* desc,
const char* attr_name,
const unsigned char* values, int num_values); const unsigned char* values, int num_values);
extern void TF_SetAttrType(TF_NodeDescription* desc, const char* attr_name, extern void TF_SetAttrType(TF_OperationDescription* desc, const char* attr_name,
TF_DataType value); TF_DataType value);
extern void TF_SetAttrTypeList(TF_NodeDescription* desc, const char* attr_name, extern void TF_SetAttrTypeList(TF_OperationDescription* desc,
const TF_DataType* values, int num_values); const char* attr_name, const TF_DataType* values,
int num_values);
// Set `num_dims` to -1 to represent "unknown rank". Otherwise, // Set `num_dims` to -1 to represent "unknown rank". Otherwise,
// `dims` points to an array of length `num_dims`. `dims[i]` must be // `dims` points to an array of length `num_dims`. `dims[i]` must be
// >= -1, with -1 meaning "unknown dimension". // >= -1, with -1 meaning "unknown dimension".
extern void TF_SetAttrShape(TF_NodeDescription* desc, const char* attr_name, extern void TF_SetAttrShape(TF_OperationDescription* desc,
const int64_t* dims, int num_dims); const char* attr_name, const int64_t* dims,
int num_dims);
// `dims` and `num_dims` must point to arrays of length `num_shapes`. // `dims` and `num_dims` must point to arrays of length `num_shapes`.
// Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise, // Set `num_dims[i]` to -1 to represent "unknown rank". Otherwise,
// `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]` // `dims[i]` points to an array of length `num_dims[i]`. `dims[i][j]`
// must be >= -1, with -1 meaning "unknown dimension". // must be >= -1, with -1 meaning "unknown dimension".
extern void TF_SetAttrShapeList(TF_NodeDescription* desc, const char* attr_name, extern void TF_SetAttrShapeList(TF_OperationDescription* desc,
const char* attr_name,
const int64_t* const* dims, const int* num_dims, const int64_t* const* dims, const int* num_dims,
int num_shapes); int num_shapes);
// `proto` must point to an array of `proto_len` bytes representing a // `proto` must point to an array of `proto_len` bytes representing a
// binary-serialized TensorShapeProto. // binary-serialized TensorShapeProto.
extern void TF_SetAttrTensorShapeProto(TF_NodeDescription* desc, extern void TF_SetAttrTensorShapeProto(TF_OperationDescription* desc,
const char* attr_name, void* proto, const char* attr_name, void* proto,
int proto_len, TF_Status* status); int proto_len, TF_Status* status);
// `protos` and `proto_lens` must point to arrays of length `num_shapes`. // `protos` and `proto_lens` must point to arrays of length `num_shapes`.
// `protos[i]` must point to an array of `proto_lens[i]` bytes // `protos[i]` must point to an array of `proto_lens[i]` bytes
// representing a binary-serialized TensorShapeProto. // representing a binary-serialized TensorShapeProto.
extern void TF_SetAttrTensorShapeProtoList(TF_NodeDescription* desc, extern void TF_SetAttrTensorShapeProtoList(TF_OperationDescription* desc,
const char* attr_name, const char* attr_name,
const void* const* protos, const void* const* protos,
const int* proto_lens, const int* proto_lens,
@ -360,11 +370,12 @@ extern void TF_SetAttrTensorShapeProtoList(TF_NodeDescription* desc,
// This functions takes ownership of *value (the // This functions takes ownership of *value (the
// implementation will eventually call TF_DeleteTensor). // implementation will eventually call TF_DeleteTensor).
extern void TF_SetAttrTensor(TF_NodeDescription* desc, const char* attr_name, extern void TF_SetAttrTensor(TF_OperationDescription* desc,
TF_Tensor* value, TF_Status* status); const char* attr_name, TF_Tensor* value,
TF_Status* status);
// This functions takes ownership of values[0]..values[num_values-1] (the // This functions takes ownership of values[0]..values[num_values-1] (the
// implementation will eventually call TF_DeleteTensor on each). // implementation will eventually call TF_DeleteTensor on each).
extern void TF_SetAttrTensorList(TF_NodeDescription* desc, extern void TF_SetAttrTensorList(TF_OperationDescription* desc,
const char* attr_name, const char* attr_name,
TF_Tensor* const* values, int num_values, TF_Tensor* const* values, int num_values,
TF_Status* status); TF_Status* status);
@ -372,100 +383,108 @@ extern void TF_SetAttrTensorList(TF_NodeDescription* desc,
// `proto` should point to a sequence of bytes of length `proto_len` // `proto` should point to a sequence of bytes of length `proto_len`
// representing a binary serialization of an AttrValue protocol // representing a binary serialization of an AttrValue protocol
// buffer. // buffer.
extern void TF_SetAttrToAttrValueProto(TF_NodeDescription* desc, extern void TF_SetAttrToAttrValueProto(TF_OperationDescription* desc,
const char* attr_name, const void* proto, const char* attr_name, const void* proto,
size_t proto_len, TF_Status* status); size_t proto_len, TF_Status* status);
// If this function succeeds: // If this function succeeds:
// * *status is set to an OK value, // * *status is set to an OK value,
// * a TF_Node is added to the graph, // * a TF_Operation is added to the graph,
// * a non-null value pointing to the added node is returned -- // * a non-null value pointing to the added operation is returned --
// this value is valid until the underlying graph is deleted. // this value is valid until the underlying graph is deleted.
// Otherwise: // Otherwise:
// * *status is set to a non-OK value, // * *status is set to a non-OK value,
// * the graph is not modified, // * the graph is not modified,
// * a null value is returned. // * a null value is returned.
// In either case, it deletes `desc`. // In either case, it deletes `desc`.
extern TF_Node* TF_FinishNode(TF_NodeDescription* desc, TF_Status* status); extern TF_Operation* TF_FinishOperation(TF_OperationDescription* desc,
TF_Status* status);
// TF_Node functions. Nodes are immutable once created, so these are all // TF_Operation functions. Operations are immutable once created, so
// query functions. // these are all query functions.
extern const char* TF_NodeName(TF_Node* node); extern const char* TF_OperationName(TF_Operation* oper);
extern const char* TF_NodeOpType(TF_Node* node); extern const char* TF_OperationOpType(TF_Operation* oper);
extern const char* TF_NodeDevice(TF_Node* node); extern const char* TF_OperationDevice(TF_Operation* oper);
extern int TF_NodeNumOutputs(TF_Node* node); extern int TF_OperationNumOutputs(TF_Operation* oper);
extern TF_DataType TF_NodeOutputType(TF_Port node_out); extern TF_DataType TF_OperationOutputType(TF_Port oper_out);
extern int TF_NodeOutputListLength(TF_Node* node, const char* arg_name, extern int TF_OperationOutputListLength(TF_Operation* oper,
TF_Status* status); const char* arg_name,
TF_Status* status);
extern int TF_NodeNumInputs(TF_Node* node); extern int TF_OperationNumInputs(TF_Operation* oper);
extern TF_DataType TF_NodeInputType(TF_Port node_in); extern TF_DataType TF_OperationInputType(TF_Port oper_in);
extern int TF_NodeInputListLength(TF_Node* node, const char* arg_name, extern int TF_OperationInputListLength(TF_Operation* oper, const char* arg_name,
TF_Status* status); TF_Status* status);
// In this code: // In this code:
// TF_Port producer = TF_NodeInput(consumer); // TF_Port producer = TF_OperationInput(consumer);
// There is an edge from producer.node's output (given by // There is an edge from producer.oper's output (given by
// producer.index) to consumer.node's input (given by consumer.index). // producer.index) to consumer.oper's input (given by consumer.index).
extern TF_Port TF_NodeInput(TF_Port node_in); extern TF_Port TF_OperationInput(TF_Port oper_in);
// Get the number of current consumers of a node's output. Note that // Get the number of current consumers of a specific output of an
// this number can change when new nodes are added to the graph. // operation. Note that this number can change when new operations
extern int TF_NodeOutputNumConsumers(TF_Port node_out); // are added to the graph.
extern int TF_OperationOutputNumConsumers(TF_Port oper_out);
// Get list of all current consumers of a node's output. consumers // Get list of all current consumers of a specific output of an
// must point to an array of length at least max_consumers (ideally // operation. `consumers` must point to an array of length at least
// set to TF_NodeOutputNumConsumer(node_out)). Beware that a // `max_consumers` (ideally set to
// concurrent modification of the graph can increase the number of // TF_OperationOutputNumConsumers(oper_out)). Beware that a concurrent
// consumers of a node. Returns the number of output consumers // modification of the graph can increase the number of consumers of
// (should match TF_NodeOutputNumConsumers(node_out)). // an operation. Returns the number of output consumers (should match
extern int TF_NodeOutputConsumers(TF_Port node_out, TF_Port* consumers, // TF_OperationOutputNumConsumers(oper_out)).
int max_consumers); extern int TF_OperationOutputConsumers(TF_Port oper_out, TF_Port* consumers,
int max_consumers);
// Get the number of control inputs to a node. // Get the number of control inputs to an operation.
extern int TF_NodeNumControlInputs(TF_Node* node); extern int TF_OperationNumControlInputs(TF_Operation* oper);
// Get list of all control inputs to a node. control_inputs must // Get list of all control inputs to an operation. `control_inputs` must
// point to an array of length max_control_inputs (ideally set to // point to an array of length `max_control_inputs` (ideally set to
// TF_NodeNumControlInputs(node)). Returns the number of control // TF_OperationNumControlInputs(oper)). Returns the number of control
// inputs (should match TF_NodeNumControlInputs(node)). // inputs (should match TF_OperationNumControlInputs(oper)).
extern int TF_NodeGetControlInputs(TF_Node* node, TF_Node** control_inputs, extern int TF_OperationGetControlInputs(TF_Operation* oper,
int max_control_inputs); TF_Operation** control_inputs,
int max_control_inputs);
// Get the number of nodes that have *node as a control inputs. // Get the number of operations that have `*oper` as a control input.
// Note that this number can change when new nodes are added to the // Note that this number can change when new operations are added to
// graph. // the graph.
extern int TF_NodeNumControlOutputs(TF_Node* node); extern int TF_OperationNumControlOutputs(TF_Operation* oper);
// Get the list of nodes that have *node as a control input. // Get the list of operations that have `*oper` as a control input.
// control_outputs must point to an array of length at least // `control_outputs` must point to an array of length at least
// max_control_outputs (ideally set to // `max_control_outputs` (ideally set to
// TF_NodeNumControlOutputs(node)). Beware that a concurrent // TF_OperationNumControlOutputs(oper)). Beware that a concurrent
// modification of the graph can increase the number of control // modification of the graph can increase the number of control
// outputs. Returns the number of control outputs (should match // outputs. Returns the number of control outputs (should match
// TF_NodeNumControlOutputs(node)). // TF_OperationNumControlOutputs(oper)).
extern int TF_NodeGetControlOutputs(TF_Node* node, TF_Node** control_outputs, extern int TF_OperationGetControlOutputs(TF_Operation* oper,
int max_control_outputs); TF_Operation** control_outputs,
int max_control_outputs);
// Sets `output_attr_value` to the binary-serialized AttrValue proto // Sets `output_attr_value` to the binary-serialized AttrValue proto
// representation of the value of the `attr_name` attr of `node`. // representation of the value of the `attr_name` attr of `oper`.
extern void TF_NodeGetAttrValueProto(TF_Node* node, const char* attr_name, extern void TF_OperationGetAttrValueProto(TF_Operation* oper,
TF_Buffer* output_attr_value, const char* attr_name,
TF_Status* status); TF_Buffer* output_attr_value,
TF_Status* status);
// Returns the node in the graph with `node_name`. Returns nullptr if // Returns the operation in the graph with `oper_name`. Returns nullptr if
// no node found. // no operation found.
extern TF_Node* TF_GraphNodeByName(TF_Graph* graph, const char* node_name); extern TF_Operation* TF_GraphOperationByName(TF_Graph* graph,
const char* oper_name);
// Iterate through the nodes of a graph. To use: // Iterate through the operations of a graph. To use:
// size_t pos = 0; // size_t pos = 0;
// TF_Node* node; // TF_Operation* oper;
// while ((node = TF_GraphNextNode(graph, &pos)) != nullptr) { // while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
// DoSomethingWithNode(node); // DoSomethingWithOperation(oper);
// } // }
extern TF_Node* TF_GraphNextNode(TF_Graph* graph, size_t* pos); extern TF_Operation* TF_GraphNextOperation(TF_Graph* graph, size_t* pos);
// Note: The following two functions may fail on very large protos in the // Note: The following two functions may fail on very large protos in the
// future. // future.
@ -473,18 +492,19 @@ extern TF_Node* TF_GraphNextNode(TF_Graph* graph, size_t* pos);
extern void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, extern void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def,
TF_Status* status); TF_Status* status);
extern void TF_NodeToNodeDef(TF_Node* node, TF_Buffer* output_node_def, extern void TF_OperationToNodeDef(TF_Operation* oper,
TF_Status* status); TF_Buffer* output_node_def,
TF_Status* status);
// TODO(josh11b): Query attrs for a Node. // TODO(josh11b): Query attrs for an operation.
// TODO(cwhipkey): Query shape for node outputs. // TODO(cwhipkey): Query shape for operation outputs.
// TODO(josh11b,mrry): Import GraphDef into TF_Graph. // TODO(josh11b,mrry): Import GraphDef into TF_Graph.
// TODO(andydavis): Function to add gradients to a graph. // TODO(andydavis): Function to add gradients to a graph.
// TODO(josh11b): Register OpDef, available to all nodes added // TODO(josh11b): Register OpDef, available to all operations added
// to this graph. // to this graph.
// The following two may both benefit from a subgraph-definition API // The following two may both benefit from a subgraph-definition API
@ -530,8 +550,8 @@ extern void TF_SessionRun(TF_SessionWithGraph* session,
// Output tensors // Output tensors
const TF_Port* outputs, TF_Tensor** output_values, const TF_Port* outputs, TF_Tensor** output_values,
int noutputs, int noutputs,
// Target nodes // Target operations
const TF_Node* const* target_nodes, int ntargets, const TF_Operation* const* target_opers, int ntargets,
// RunMetadata // RunMetadata
TF_Buffer* run_metadata, TF_Buffer* run_metadata,
// Output status // Output status
@ -543,8 +563,8 @@ extern void TF_SessionPRunSetup(TF_SessionWithGraph*,
const TF_Port* inputs, int ninputs, const TF_Port* inputs, int ninputs,
// Output names // Output names
const TF_Port* outputs, int noutputs, const TF_Port* outputs, int noutputs,
// Target nodes // Target operations
const TF_Node* const* target_nodes, const TF_Operation* const* target_opers,
int ntargets, int ntargets,
// Output handle // Output handle
const char** handle, const char** handle,
@ -559,8 +579,9 @@ extern void TF_SessionPRun(TF_SessionWithGraph*, const char* handle,
// Output tensors // Output tensors
const TF_Port* outputs, TF_Tensor** output_values, const TF_Port* outputs, TF_Tensor** output_values,
int noutputs, int noutputs,
// Target nodes // Target operations
const TF_Node* const* target_nodes, int ntargets, const TF_Operation* const* target_opers,
int ntargets,
// Output status // Output status
TF_Status*); TF_Status*);
@ -622,10 +643,9 @@ extern void TF_Run(TF_Session*,
// Input tensors // Input tensors
const char** input_names, TF_Tensor** inputs, int ninputs, const char** input_names, TF_Tensor** inputs, int ninputs,
// Output tensors // Output tensors
const char** output_tensor_names, TF_Tensor** outputs, const char** output_names, TF_Tensor** outputs, int noutputs,
int noutputs, // Target operations
// Target nodes const char** target_oper_names, int ntargets,
const char** target_node_names, int ntargets,
// RunMetadata // RunMetadata
TF_Buffer* run_metadata, TF_Buffer* run_metadata,
// Output status // Output status
@ -643,9 +663,9 @@ extern void TF_PRunSetup(TF_Session*,
// Input names // Input names
const char** input_names, int ninputs, const char** input_names, int ninputs,
// Output names // Output names
const char** output_tensor_names, int noutputs, const char** output_names, int noutputs,
// Target nodes // Target operations
const char** target_node_names, int ntargets, const char** target_oper_names, int ntargets,
// Output handle // Output handle
const char** handle, const char** handle,
// Output status // Output status
@ -658,10 +678,10 @@ extern void TF_PRun(TF_Session*, const char* handle,
// Input tensors // Input tensors
const char** input_names, TF_Tensor** inputs, int ninputs, const char** input_names, TF_Tensor** inputs, int ninputs,
// Output tensors // Output tensors
const char** output_tensor_names, TF_Tensor** outputs, const char** output_names, TF_Tensor** outputs,
int noutputs, int noutputs,
// Target nodes // Target operations
const char** target_node_names, int ntargets, const char** target_oper_names, int ntargets,
// Output status // Output status
TF_Status*); TF_Status*);

View File

@ -202,32 +202,33 @@ static TF_Tensor* Int32Tensor(int32 v) {
&Int32Deallocator, nullptr); &Int32Deallocator, nullptr);
} }
TF_Node* Placeholder(TF_Graph* graph, TF_Status* s) { TF_Operation* Placeholder(TF_Graph* graph, TF_Status* s) {
TF_NodeDescription* desc = TF_NewNode(graph, "Placeholder", "feed"); TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", "feed");
TF_SetAttrType(desc, "dtype", TF_INT32); TF_SetAttrType(desc, "dtype", TF_INT32);
return TF_FinishNode(desc, s); return TF_FinishOperation(desc, s);
} }
TF_Node* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) { TF_Operation* ScalarConst(int32 v, TF_Graph* graph, TF_Status* s) {
TF_NodeDescription* desc = TF_NewNode(graph, "Const", "scalar"); TF_OperationDescription* desc = TF_NewOperation(graph, "Const", "scalar");
TF_SetAttrTensor(desc, "value", Int32Tensor(v), s); TF_SetAttrTensor(desc, "value", Int32Tensor(v), s);
if (TF_GetCode(s) != TF_OK) return nullptr; if (TF_GetCode(s) != TF_OK) return nullptr;
TF_SetAttrType(desc, "dtype", TF_INT32); TF_SetAttrType(desc, "dtype", TF_INT32);
return TF_FinishNode(desc, s); return TF_FinishOperation(desc, s);
} }
TF_Node* Add(TF_Node* l, TF_Node* r, TF_Graph* graph, TF_Status* s) { TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph,
TF_NodeDescription* desc = TF_NewNode(graph, "AddN", "add"); TF_Status* s) {
TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", "add");
TF_Port add_inputs[2] = {{l, 0}, {r, 0}}; TF_Port add_inputs[2] = {{l, 0}, {r, 0}};
TF_AddInputList(desc, add_inputs, 2); TF_AddInputList(desc, add_inputs, 2);
return TF_FinishNode(desc, s); return TF_FinishOperation(desc, s);
} }
TF_Node* Neg(TF_Node* n, TF_Graph* graph, TF_Status* s) { TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s) {
TF_NodeDescription* desc = TF_NewNode(graph, "Neg", "neg"); TF_OperationDescription* desc = TF_NewOperation(graph, "Neg", "neg");
TF_Port neg_input = {n, 0}; TF_Port neg_input = {n, 0};
TF_AddInput(desc, neg_input); TF_AddInput(desc, neg_input);
return TF_FinishNode(desc, s); return TF_FinishOperation(desc, s);
} }
bool IsPlaceholder(const NodeDef& node_def) { bool IsPlaceholder(const NodeDef& node_def) {
@ -318,10 +319,10 @@ bool GetGraphDef(TF_Graph* graph, GraphDef* graph_def) {
return ret; return ret;
} }
bool GetNodeDef(TF_Node* node, NodeDef* node_def) { bool GetNodeDef(TF_Operation* oper, NodeDef* node_def) {
TF_Status* s = TF_NewStatus(); TF_Status* s = TF_NewStatus();
TF_Buffer* buffer = TF_NewBuffer(); TF_Buffer* buffer = TF_NewBuffer();
TF_NodeToNodeDef(node, buffer, s); TF_OperationToNodeDef(oper, buffer, s);
bool ret = TF_GetCode(s) == TF_OK; bool ret = TF_GetCode(s) == TF_OK;
EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); EXPECT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length); if (ret) ret = node_def->ParseFromArray(buffer->data, buffer->length);
@ -330,10 +331,10 @@ bool GetNodeDef(TF_Node* node, NodeDef* node_def) {
return ret; return ret;
} }
bool GetAttrValue(TF_Node* node, const char* attr_name, bool GetAttrValue(TF_Operation* oper, const char* attr_name,
tensorflow::AttrValue* attr_value, TF_Status* s) { tensorflow::AttrValue* attr_value, TF_Status* s) {
TF_Buffer* buffer = TF_NewBuffer(); TF_Buffer* buffer = TF_NewBuffer();
TF_NodeGetAttrValueProto(node, attr_name, buffer, s); TF_OperationGetAttrValueProto(oper, attr_name, buffer, s);
bool ret = TF_GetCode(s) == TF_OK; bool ret = TF_GetCode(s) == TF_OK;
if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length); if (ret) ret = attr_value->ParseFromArray(buffer->data, buffer->length);
TF_DeleteBuffer(buffer); TF_DeleteBuffer(buffer);
@ -344,82 +345,83 @@ TEST(CAPI, Graph) {
TF_Status* s = TF_NewStatus(); TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph(); TF_Graph* graph = TF_NewGraph();
// Make a placeholder node. // Make a placeholder oper.
TF_Node* feed = Placeholder(graph, s); TF_Operation* feed = Placeholder(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Test TF_Node*() query functions. // Test TF_Operation*() query functions.
EXPECT_EQ(string("feed"), string(TF_NodeName(feed))); EXPECT_EQ(string("feed"), string(TF_OperationName(feed)));
EXPECT_EQ(string("Placeholder"), string(TF_NodeOpType(feed))); EXPECT_EQ(string("Placeholder"), string(TF_OperationOpType(feed)));
EXPECT_EQ(string(""), string(TF_NodeDevice(feed))); EXPECT_EQ(string(""), string(TF_OperationDevice(feed)));
EXPECT_EQ(1, TF_NodeNumOutputs(feed)); EXPECT_EQ(1, TF_OperationNumOutputs(feed));
EXPECT_EQ(TF_INT32, TF_NodeOutputType(TF_Port{feed, 0})); EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Port{feed, 0}));
EXPECT_EQ(1, TF_NodeOutputListLength(feed, "output", s)); EXPECT_EQ(1, TF_OperationOutputListLength(feed, "output", s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(0, TF_NodeNumInputs(feed)); EXPECT_EQ(0, TF_OperationNumInputs(feed));
EXPECT_EQ(0, TF_NodeOutputNumConsumers(TF_Port{feed, 0})); EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Port{feed, 0}));
EXPECT_EQ(0, TF_NodeNumControlInputs(feed)); EXPECT_EQ(0, TF_OperationNumControlInputs(feed));
EXPECT_EQ(0, TF_NodeNumControlOutputs(feed)); EXPECT_EQ(0, TF_OperationNumControlOutputs(feed));
tensorflow::AttrValue attr_value; tensorflow::AttrValue attr_value;
ASSERT_TRUE(GetAttrValue(feed, "dtype", &attr_value, s)) << TF_Message(s); ASSERT_TRUE(GetAttrValue(feed, "dtype", &attr_value, s)) << TF_Message(s);
EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32); EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
// Test not found errors in TF_Node*() query functions. // Test not found errors in TF_Operation*() query functions.
EXPECT_EQ(-1, TF_NodeOutputListLength(feed, "bogus", s)); EXPECT_EQ(-1, TF_OperationOutputListLength(feed, "bogus", s));
EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s)); EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s));
ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s)); ASSERT_FALSE(GetAttrValue(feed, "missing", &attr_value, s));
EXPECT_EQ(string("Node has no attr named 'missing'."), string(TF_Message(s))); EXPECT_EQ(string("Operation has no attr named 'missing'."),
string(TF_Message(s)));
// Make a constant node with the scalar "3". // Make a constant oper with the scalar "3".
TF_Node* three = ScalarConst(3, graph, s); TF_Operation* three = ScalarConst(3, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Add node. // Add oper.
TF_Node* add = Add(feed, three, graph, s); TF_Operation* add = Add(feed, three, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Test TF_Node*() query functions. // Test TF_Operation*() query functions.
EXPECT_EQ(string("add"), string(TF_NodeName(add))); EXPECT_EQ(string("add"), string(TF_OperationName(add)));
EXPECT_EQ(string("AddN"), string(TF_NodeOpType(add))); EXPECT_EQ(string("AddN"), string(TF_OperationOpType(add)));
EXPECT_EQ(string(""), string(TF_NodeDevice(add))); EXPECT_EQ(string(""), string(TF_OperationDevice(add)));
EXPECT_EQ(1, TF_NodeNumOutputs(add)); EXPECT_EQ(1, TF_OperationNumOutputs(add));
EXPECT_EQ(TF_INT32, TF_NodeOutputType(TF_Port{add, 0})); EXPECT_EQ(TF_INT32, TF_OperationOutputType(TF_Port{add, 0}));
EXPECT_EQ(1, TF_NodeOutputListLength(add, "sum", s)); EXPECT_EQ(1, TF_OperationOutputListLength(add, "sum", s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(2, TF_NodeNumInputs(add)); EXPECT_EQ(2, TF_OperationNumInputs(add));
EXPECT_EQ(2, TF_NodeInputListLength(add, "inputs", s)); EXPECT_EQ(2, TF_OperationInputListLength(add, "inputs", s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
EXPECT_EQ(TF_INT32, TF_NodeInputType(TF_Port{add, 0})); EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Port{add, 0}));
EXPECT_EQ(TF_INT32, TF_NodeInputType(TF_Port{add, 1})); EXPECT_EQ(TF_INT32, TF_OperationInputType(TF_Port{add, 1}));
TF_Port add_in_0 = TF_NodeInput(TF_Port{add, 0}); TF_Port add_in_0 = TF_OperationInput(TF_Port{add, 0});
EXPECT_EQ(feed, add_in_0.node); EXPECT_EQ(feed, add_in_0.oper);
EXPECT_EQ(0, add_in_0.index); EXPECT_EQ(0, add_in_0.index);
TF_Port add_in_1 = TF_NodeInput(TF_Port{add, 1}); TF_Port add_in_1 = TF_OperationInput(TF_Port{add, 1});
EXPECT_EQ(three, add_in_1.node); EXPECT_EQ(three, add_in_1.oper);
EXPECT_EQ(0, add_in_1.index); EXPECT_EQ(0, add_in_1.index);
EXPECT_EQ(0, TF_NodeOutputNumConsumers(TF_Port{add, 0})); EXPECT_EQ(0, TF_OperationOutputNumConsumers(TF_Port{add, 0}));
EXPECT_EQ(0, TF_NodeNumControlInputs(add)); EXPECT_EQ(0, TF_OperationNumControlInputs(add));
EXPECT_EQ(0, TF_NodeNumControlOutputs(add)); EXPECT_EQ(0, TF_OperationNumControlOutputs(add));
ASSERT_TRUE(GetAttrValue(add, "T", &attr_value, s)) << TF_Message(s); ASSERT_TRUE(GetAttrValue(add, "T", &attr_value, s)) << TF_Message(s);
EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32); EXPECT_EQ(attr_value.type(), tensorflow::DT_INT32);
ASSERT_TRUE(GetAttrValue(add, "N", &attr_value, s)) << TF_Message(s); ASSERT_TRUE(GetAttrValue(add, "N", &attr_value, s)) << TF_Message(s);
EXPECT_EQ(attr_value.i(), 2); EXPECT_EQ(attr_value.i(), 2);
// Placeholder node now has a consumer. // Placeholder oper now has a consumer.
ASSERT_EQ(1, TF_NodeOutputNumConsumers(TF_Port{feed, 0})); ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Port{feed, 0}));
TF_Port feed_port; TF_Port feed_port;
EXPECT_EQ(1, TF_NodeOutputConsumers(TF_Port{feed, 0}, &feed_port, 1)); EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Port{feed, 0}, &feed_port, 1));
EXPECT_EQ(add, feed_port.node); EXPECT_EQ(add, feed_port.oper);
EXPECT_EQ(0, feed_port.index); EXPECT_EQ(0, feed_port.index);
// The scalar const node also has a consumer. // The scalar const oper also has a consumer.
ASSERT_EQ(1, TF_NodeOutputNumConsumers(TF_Port{three, 0})); ASSERT_EQ(1, TF_OperationOutputNumConsumers(TF_Port{three, 0}));
TF_Port three_port; TF_Port three_port;
EXPECT_EQ(1, TF_NodeOutputConsumers(TF_Port{three, 0}, &three_port, 1)); EXPECT_EQ(1, TF_OperationOutputConsumers(TF_Port{three, 0}, &three_port, 1));
EXPECT_EQ(add, three_port.node); EXPECT_EQ(add, three_port.oper);
EXPECT_EQ(1, three_port.index); EXPECT_EQ(1, three_port.index);
// Serialize to GraphDef. // Serialize to GraphDef.
@ -448,8 +450,8 @@ TEST(CAPI, Graph) {
EXPECT_TRUE(found_scalar_const); EXPECT_TRUE(found_scalar_const);
EXPECT_TRUE(found_add); EXPECT_TRUE(found_add);
// Add another node to the graph. // Add another oper to the graph.
TF_Node* neg = Neg(add, graph, s); TF_Operation* neg = Neg(add, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Serialize to NodeDef. // Serialize to NodeDef.
@ -469,13 +471,13 @@ TEST(CAPI, Graph) {
EXPECT_EQ(ProtoDebugString(graph_def), ProtoDebugString(graph_def2)); EXPECT_EQ(ProtoDebugString(graph_def), ProtoDebugString(graph_def2));
// Look up some nodes by name. // Look up some nodes by name.
TF_Node* neg2 = TF_GraphNodeByName(graph, "neg"); TF_Operation* neg2 = TF_GraphOperationByName(graph, "neg");
EXPECT_TRUE(neg == neg2); EXPECT_TRUE(neg == neg2);
NodeDef node_def2; NodeDef node_def2;
ASSERT_TRUE(GetNodeDef(neg2, &node_def2)); ASSERT_TRUE(GetNodeDef(neg2, &node_def2));
EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2)); EXPECT_EQ(ProtoDebugString(node_def), ProtoDebugString(node_def2));
TF_Node* feed2 = TF_GraphNodeByName(graph, "feed"); TF_Operation* feed2 = TF_GraphOperationByName(graph, "feed");
EXPECT_TRUE(feed == feed2); EXPECT_TRUE(feed == feed2);
ASSERT_TRUE(GetNodeDef(feed, &node_def)); ASSERT_TRUE(GetNodeDef(feed, &node_def));
ASSERT_TRUE(GetNodeDef(feed2, &node_def2)); ASSERT_TRUE(GetNodeDef(feed2, &node_def2));
@ -487,22 +489,22 @@ TEST(CAPI, Graph) {
found_add = false; found_add = false;
bool found_neg = false; bool found_neg = false;
size_t pos = 0; size_t pos = 0;
TF_Node* node; TF_Operation* oper;
while ((node = TF_GraphNextNode(graph, &pos)) != nullptr) { while ((oper = TF_GraphNextOperation(graph, &pos)) != nullptr) {
if (node == feed) { if (oper == feed) {
EXPECT_FALSE(found_placeholder); EXPECT_FALSE(found_placeholder);
found_placeholder = true; found_placeholder = true;
} else if (node == three) { } else if (oper == three) {
EXPECT_FALSE(found_scalar_const); EXPECT_FALSE(found_scalar_const);
found_scalar_const = true; found_scalar_const = true;
} else if (node == add) { } else if (oper == add) {
EXPECT_FALSE(found_add); EXPECT_FALSE(found_add);
found_add = true; found_add = true;
} else if (node == neg) { } else if (oper == neg) {
EXPECT_FALSE(found_neg); EXPECT_FALSE(found_neg);
found_neg = true; found_neg = true;
} else { } else {
ASSERT_TRUE(GetNodeDef(node, &node_def)); ASSERT_TRUE(GetNodeDef(oper, &node_def));
ADD_FAILURE() << "Unexpected Node: " << ProtoDebugString(node_def); ADD_FAILURE() << "Unexpected Node: " << ProtoDebugString(node_def);
} }
} }
@ -532,7 +534,7 @@ class CSessionWithGraph {
} }
void SetInputs( void SetInputs(
std::initializer_list<std::pair<TF_Node*, TF_Tensor*>> inputs) { std::initializer_list<std::pair<TF_Operation*, TF_Tensor*>> inputs) {
DeleteInputValues(); DeleteInputValues();
inputs_.clear(); inputs_.clear();
for (const auto& p : inputs) { for (const auto& p : inputs) {
@ -541,17 +543,17 @@ class CSessionWithGraph {
} }
} }
void SetOutputs(std::initializer_list<TF_Node*> outputs) { void SetOutputs(std::initializer_list<TF_Operation*> outputs) {
ResetOutputValues(); ResetOutputValues();
outputs_.clear(); outputs_.clear();
for (TF_Node* o : outputs) { for (TF_Operation* o : outputs) {
outputs_.emplace_back(TF_Port{o, 0}); outputs_.emplace_back(TF_Port{o, 0});
} }
} }
void SetTargets(std::initializer_list<TF_Node*> targets) { void SetTargets(std::initializer_list<TF_Operation*> targets) {
targets_.clear(); targets_.clear();
for (TF_Node* t : targets) { for (TF_Operation* t : targets) {
targets_.emplace_back(t); targets_.emplace_back(t);
} }
} }
@ -572,7 +574,8 @@ class CSessionWithGraph {
TF_Tensor** output_values_ptr = TF_Tensor** output_values_ptr =
output_values_.empty() ? nullptr : &output_values_[0]; output_values_.empty() ? nullptr : &output_values_[0];
TF_Node* const* targets_ptr = targets_.empty() ? nullptr : &targets_[0]; TF_Operation* const* targets_ptr =
targets_.empty() ? nullptr : &targets_[0];
TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr, TF_SessionRun(session_, nullptr, inputs_ptr, input_values_ptr,
inputs_.size(), outputs_ptr, output_values_ptr, inputs_.size(), outputs_ptr, output_values_ptr,
@ -615,23 +618,23 @@ class CSessionWithGraph {
std::vector<TF_Tensor*> input_values_; std::vector<TF_Tensor*> input_values_;
std::vector<TF_Port> outputs_; std::vector<TF_Port> outputs_;
std::vector<TF_Tensor*> output_values_; std::vector<TF_Tensor*> output_values_;
std::vector<TF_Node*> targets_; std::vector<TF_Operation*> targets_;
}; };
TEST(CAPI, SessionWithGraph) { TEST(CAPI, SessionWithGraph) {
TF_Status* s = TF_NewStatus(); TF_Status* s = TF_NewStatus();
TF_Graph* graph = TF_NewGraph(); TF_Graph* graph = TF_NewGraph();
// Make a placeholder node. // Make a placeholder operation.
TF_Node* feed = Placeholder(graph, s); TF_Operation* feed = Placeholder(graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Make a constant node with the scalar "2". // Make a constant operation with the scalar "2".
TF_Node* two = ScalarConst(2, graph, s); TF_Operation* two = ScalarConst(2, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Add node. // Add operation.
TF_Node* add = Add(feed, two, graph, s); TF_Operation* add = Add(feed, two, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a session for this graph. // Create a session for this graph.
@ -652,11 +655,11 @@ TEST(CAPI, SessionWithGraph) {
static_cast<tensorflow::int32*>(TF_TensorData(out)); static_cast<tensorflow::int32*>(TF_TensorData(out));
EXPECT_EQ(3 + 2, *output_contents); EXPECT_EQ(3 + 2, *output_contents);
// Add another node to the graph. // Add another operation to the graph.
TF_Node* neg = Neg(add, graph, s); TF_Operation* neg = Neg(add, graph, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Run up to the new node. // Run up to the new operation.
csession.SetInputs({{feed, Int32Tensor(7)}}); csession.SetInputs({{feed, Int32Tensor(7)}});
csession.SetOutputs({neg}); csession.SetOutputs({neg});
csession.Run(s); csession.Run(s);

View File

@ -243,4 +243,4 @@ try:
plot_with_labels(low_dim_embs, labels) plot_with_labels(low_dim_embs, labels)
except ImportError: except ImportError:
print("Please install sklearn and matplotlib to visualize embeddings.") print("Please install sklearn, matplotlib, and scipy to visualize embeddings.")

View File

@ -248,7 +248,7 @@ class Word2Vec(object):
true_logits = tf.reduce_sum(tf.mul(example_emb, true_w), 1) + true_b true_logits = tf.reduce_sum(tf.mul(example_emb, true_w), 1) + true_b
# Sampled logits: [batch_size, num_sampled] # Sampled logits: [batch_size, num_sampled]
# We replicate sampled noise lables for all examples in the batch # We replicate sampled noise labels for all examples in the batch
# using the matmul. # using the matmul.
sampled_b_vec = tf.reshape(sampled_b, [opts.num_samples]) sampled_b_vec = tf.reshape(sampled_b, [opts.num_samples])
sampled_logits = tf.matmul(example_emb, sampled_logits = tf.matmul(example_emb,

View File

@ -1 +1 @@
23 25

View File

@ -29,6 +29,7 @@ import imghdr
import json import json
import mimetypes import mimetypes
import os import os
import re
from six import BytesIO from six import BytesIO
from six.moves import BaseHTTPServer from six.moves import BaseHTTPServer
@ -65,6 +66,11 @@ _IMGHDR_TO_MIMETYPE = {
} }
_DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream' _DEFAULT_IMAGE_MIMETYPE = 'application/octet-stream'
# Allows *, gzip or x-gzip, but forbid gzip;q=0
# https://tools.ietf.org/html/rfc7231#section-5.3.4
_ALLOWS_GZIP_PATTERN = re.compile(
r'(?:^|,|\s)(?:(?:x-)?gzip|\*)(?!;q=0)(?:\s|,|$)')
def _content_type_for_image(encoded_image_string): def _content_type_for_image(encoded_image_string):
image_type = imghdr.what(None, encoded_image_string) image_type = imghdr.what(None, encoded_image_string)
@ -91,6 +97,10 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
# How many samples to include in sampling API calls by default. # How many samples to include in sampling API calls by default.
DEFAULT_SAMPLE_COUNT = 10 DEFAULT_SAMPLE_COUNT = 10
# NOTE TO MAINTAINERS: An accurate Content-Length MUST be specified on all
# responses using send_header.
protocol_version = 'HTTP/1.1'
def __init__(self, multiplexer, *args): def __init__(self, multiplexer, *args):
self._multiplexer = multiplexer self._multiplexer = multiplexer
BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args) BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, *args)
@ -162,25 +172,54 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
prefix = os.path.commonprefix([base, absolute_path]) prefix = os.path.commonprefix([base, absolute_path])
return prefix == base return prefix == base
def _respond(self, content, content_type, code=200, encoding=None):
"""Sends HTTP response.
All text responses are assumed to be utf-8 unless specified otherwise.
Args:
content: The content to respond with, which is converted to bytes.
content_type: The mime type of the content.
code: The numeric HTTP status code to use.
encoding: The encoding if any (not sanity checked.)
"""
content = compat.as_bytes(content)
self.send_response(code)
if content_type.startswith(('text/', 'application/json')):
if 'charset=' not in content_type:
content_type += '; charset=utf-8'
self.send_header('Content-Type', content_type)
self.send_header('Content-Length', len(content))
if encoding:
self.send_header('Content-Encoding', encoding)
self.end_headers()
self.wfile.write(content)
def _is_gzip_accepted(self):
"""Returns true if Accept-Encoding contains gzip."""
accept_encoding = self.headers.get('Accept-Encoding', '')
return _ALLOWS_GZIP_PATTERN.search(accept_encoding) is not None
def _send_gzip_response(self, content, content_type, code=200): def _send_gzip_response(self, content, content_type, code=200):
"""Writes the given content as gzip response using the given content type. """Writes the given content as gzip response using the given content type.
If the HTTP client does not accept gzip encoding, then the response will be
sent uncompressed.
Args: Args:
content: The content to respond with. content: The content to respond with.
content_type: The mime type of the content. content_type: The mime type of the content.
code: The numeric HTTP status code to use. code: The numeric HTTP status code to use.
""" """
out = BytesIO() encoding = None
f = gzip.GzipFile(fileobj=out, mode='wb') if self._is_gzip_accepted():
f.write(compat.as_bytes(content)) out = BytesIO()
f.close() f = gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3)
gzip_content = out.getvalue() f.write(compat.as_bytes(content))
self.send_response(code) f.close()
self.send_header('Content-Type', content_type) content = out.getvalue()
self.send_header('Content-Length', len(gzip_content)) encoding = 'gzip'
self.send_header('Content-Encoding', 'gzip') self._respond(content, content_type, code, encoding)
self.end_headers()
self.wfile.write(gzip_content)
def _send_json_response(self, obj, code=200): def _send_json_response(self, obj, code=200):
"""Writes out the given object as JSON using the given HTTP status code. """Writes out the given object as JSON using the given HTTP status code.
@ -191,14 +230,8 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
obj: The object to respond with. obj: The object to respond with.
code: The numeric HTTP status code to use. code: The numeric HTTP status code to use.
""" """
content = json.dumps(json_util.WrapSpecialFloats(obj))
output = json.dumps(json_util.WrapSpecialFloats(obj)) self._respond(content, 'application/json', code)
self.send_response(code)
self.send_header('Content-Type', 'application/json')
self.send_header('Content-Length', len(output))
self.end_headers()
self.wfile.write(compat.as_bytes(output))
def _send_csv_response(self, serialized_csv, code=200): def _send_csv_response(self, serialized_csv, code=200):
"""Writes out the given string, which represents CSV data. """Writes out the given string, which represents CSV data.
@ -210,12 +243,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
serialized_csv: A string containing some CSV data. serialized_csv: A string containing some CSV data.
code: The numeric HTTP status code to use. code: The numeric HTTP status code to use.
""" """
self._respond(serialized_csv, 'text/csv', code)
self.send_response(code)
self.send_header('Content-Type', 'text/csv')
self.send_header('Content-Length', len(serialized_csv))
self.end_headers()
self.wfile.write(serialized_csv)
def _serve_scalars(self, query_params): def _serve_scalars(self, query_params):
"""Given a tag and single run, return array of ScalarEvents. """Given a tag and single run, return array of ScalarEvents.
@ -372,12 +400,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
image = self._multiplexer.Images(run, tag)[index] image = self._multiplexer.Images(run, tag)[index]
encoded_image_string = image.encoded_image_string encoded_image_string = image.encoded_image_string
content_type = _content_type_for_image(encoded_image_string) content_type = _content_type_for_image(encoded_image_string)
self._respond(encoded_image_string, content_type)
self.send_response(200)
self.send_header('Content-Type', content_type)
self.send_header('Content-Length', len(encoded_image_string))
self.end_headers()
self.wfile.write(encoded_image_string)
def _query_for_individual_image(self, run, tag, index): def _query_for_individual_image(self, run, tag, index):
"""Builds a URL for accessing the specified image. """Builds a URL for accessing the specified image.
@ -429,12 +452,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
audio = self._multiplexer.Audio(run, tag)[index] audio = self._multiplexer.Audio(run, tag)[index]
encoded_audio_string = audio.encoded_audio_string encoded_audio_string = audio.encoded_audio_string
content_type = audio.content_type content_type = audio.content_type
self._respond(encoded_audio_string, content_type)
self.send_response(200)
self.send_header('Content-Type', content_type)
self.send_header('Content-Length', len(encoded_audio_string))
self.end_headers()
self.wfile.write(encoded_audio_string)
def _query_for_individual_audio(self, run, tag, index): def _query_for_individual_audio(self, run, tag, index):
"""Builds a URL for accessing the specified audio. """Builds a URL for accessing the specified audio.
@ -523,13 +541,9 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
logging.info('path %s not found, sending 404', path) logging.info('path %s not found, sending 404', path)
self.send_error(404) self.send_error(404)
return return
mimetype, encoding = mimetypes.guess_type(path)
self.send_response(200) mimetype = mimetype or 'application/octet-stream'
self._respond(contents, mimetype, encoding=encoding)
mimetype = mimetypes.guess_type(path)[0] or 'application/octet-stream'
self.send_header('Content-Type', mimetype)
self.end_headers()
self.wfile.write(contents)
def do_GET(self): # pylint: disable=invalid-name def do_GET(self): # pylint: disable=invalid-name
"""Handler for all get requests.""" """Handler for all get requests."""

View File

@ -41,7 +41,7 @@ TENSORBOARD_SIZE_GUIDANCE = {
event_accumulator.IMAGES: 4, event_accumulator.IMAGES: 4,
event_accumulator.AUDIO: 4, event_accumulator.AUDIO: 4,
event_accumulator.SCALARS: 1000, event_accumulator.SCALARS: 1000,
event_accumulator.HISTOGRAMS: 1, event_accumulator.HISTOGRAMS: 50,
} }
@ -80,11 +80,8 @@ def ParseEventFilesSpec(logdir):
else: else:
run_name = None run_name = None
path = specification path = specification
if not gcs.IsGCSPath(path):
if not os.path.isabs(path) and not gcs.IsGCSPath(path): path = os.path.realpath(os.path.expanduser(path))
# Create absolute path out of relative one.
path = os.path.join(os.path.realpath('.'), path)
files[path] = run_name files[path] = run_name
return files return files

View File

@ -64,9 +64,9 @@ class TensorboardServerTest(tf.test.TestCase):
self._server.shutdown() self._server.shutdown()
self._server.server_close() self._server.server_close()
def _get(self, path): def _get(self, path, headers={}):
"""Perform a GET request for the given path.""" """Perform a GET request for the given path."""
self._connection.request('GET', path) self._connection.request('GET', path, None, headers)
return self._connection.getresponse() return self._connection.getresponse()
def _getJson(self, path): def _getJson(self, path):
@ -76,18 +76,6 @@ class TensorboardServerTest(tf.test.TestCase):
self.assertEqual(response.status, 200) self.assertEqual(response.status, 200)
return json.loads(response.read().decode('utf-8')) return json.loads(response.read().decode('utf-8'))
def _decodeResponse(self, response):
"""Decompresses (if necessary) the response from the server."""
encoding = response.getheader('Content-Encoding')
content = response.read()
if encoding in ('gzip', 'x-gzip', 'deflate'):
if encoding == 'deflate':
data = BytesIO(zlib.decompress(content))
else:
data = gzip.GzipFile('', 'rb', 9, BytesIO(content))
content = data.read()
return content
def testBasicStartup(self): def testBasicStartup(self):
"""Start the server up and then shut it down immediately.""" """Start the server up and then shut it down immediately."""
pass pass
@ -180,8 +168,7 @@ class TensorboardServerTest(tf.test.TestCase):
response = self._get('/data/graph?run=run1&limit_attr_size=1024' response = self._get('/data/graph?run=run1&limit_attr_size=1024'
'&large_attrs_key=_very_large_attrs') '&large_attrs_key=_very_large_attrs')
self.assertEqual(response.status, 200) self.assertEqual(response.status, 200)
# Decompress (unzip) the response, since graphs come gzipped. graph_pbtxt = response.read()
graph_pbtxt = self._decodeResponse(response)
# Parse the graph from pbtxt into a graph message. # Parse the graph from pbtxt into a graph message.
graph = tf.GraphDef() graph = tf.GraphDef()
graph = text_format.Parse(graph_pbtxt, graph) graph = text_format.Parse(graph_pbtxt, graph)
@ -194,12 +181,40 @@ class TensorboardServerTest(tf.test.TestCase):
self.assertEqual(graph.node[1].attr['_very_large_attrs'].list.s, self.assertEqual(graph.node[1].attr['_very_large_attrs'].list.s,
[b'very_large_attr']) [b'very_large_attr'])
def testAcceptGzip_compressesResponse(self):
response = self._get('/data/graph?run=run1&limit_attr_size=1024'
'&large_attrs_key=_very_large_attrs',
{'Accept-Encoding': 'gzip'})
self.assertEqual(response.status, 200)
self.assertEqual(response.getheader('Content-Encoding'), 'gzip')
pbtxt = gzip.GzipFile('', 'rb', 9, BytesIO(response.read())).read()
graph = text_format.Parse(pbtxt, tf.GraphDef())
self.assertEqual(len(graph.node), 2)
def testAcceptAnyEncoding_compressesResponse(self):
response = self._get('/data/graph?run=run1&limit_attr_size=1024'
'&large_attrs_key=_very_large_attrs',
{'Accept-Encoding': '*'})
self.assertEqual(response.status, 200)
self.assertEqual(response.getheader('Content-Encoding'), 'gzip')
pbtxt = gzip.GzipFile('', 'rb', 9, BytesIO(response.read())).read()
graph = text_format.Parse(pbtxt, tf.GraphDef())
self.assertEqual(len(graph.node), 2)
def testAcceptDoodleEncoding_doesNotCompressResponse(self):
response = self._get('/data/graph?run=run1&limit_attr_size=1024'
'&large_attrs_key=_very_large_attrs',
{'Accept-Encoding': 'doodle'})
self.assertEqual(response.status, 200)
self.assertIsNone(response.getheader('Content-Encoding'))
graph = text_format.Parse(response.read(), tf.GraphDef())
self.assertEqual(len(graph.node), 2)
def testRunMetadata(self): def testRunMetadata(self):
"""Test retrieving the run metadata information.""" """Test retrieving the run metadata information."""
response = self._get('/data/run_metadata?run=run1&tag=test%20run') response = self._get('/data/run_metadata?run=run1&tag=test%20run')
self.assertEqual(response.status, 200) self.assertEqual(response.status, 200)
# Decompress (unzip) the response, since run outputs come gzipped. run_metadata_pbtxt = response.read()
run_metadata_pbtxt = self._decodeResponse(response)
# Parse from pbtxt into a message. # Parse from pbtxt into a message.
run_metadata = tf.RunMetadata() run_metadata = tf.RunMetadata()
text_format.Parse(run_metadata_pbtxt, run_metadata) text_format.Parse(run_metadata_pbtxt, run_metadata)
@ -283,11 +298,46 @@ class TensorboardServerTest(tf.test.TestCase):
class ParseEventFilesSpecTest(tf.test.TestCase): class ParseEventFilesSpecTest(tf.test.TestCase):
def testRunName(self):
logdir_string = 'lol:/cat'
expected = {'/cat': 'lol'}
self.assertEqual(server.ParseEventFilesSpec(logdir_string), expected)
def testPathWithColonThatComesAfterASlash_isNotConsideredARunName(self):
logdir_string = '/lol:/cat'
expected = {'/lol:/cat': None}
self.assertEqual(server.ParseEventFilesSpec(logdir_string), expected)
def testMultipleDirectories(self):
logdir_string = '/a,/b'
expected = {'/a': None, '/b': None}
self.assertEqual(server.ParseEventFilesSpec(logdir_string), expected)
def testNormalizesPaths(self):
logdir_string = '/lol/.//cat/../cat'
expected = {'/lol/cat': None}
self.assertEqual(server.ParseEventFilesSpec(logdir_string), expected)
def testAbsolutifies(self):
logdir_string = 'lol/cat'
expected = {os.path.realpath('lol/cat'): None}
self.assertEqual(server.ParseEventFilesSpec(logdir_string), expected)
def testRespectsGCSPath(self): def testRespectsGCSPath(self):
logdir_string = 'gs://foo/path' logdir_string = 'gs://foo/path'
expected = {'gs://foo/path': None} expected = {'gs://foo/path': None}
self.assertEqual(server.ParseEventFilesSpec(logdir_string), expected) self.assertEqual(server.ParseEventFilesSpec(logdir_string), expected)
def testDoesNotExpandUserInGCSPath(self):
logdir_string = 'gs://~/foo/path'
expected = {'gs://~/foo/path': None}
self.assertEqual(server.ParseEventFilesSpec(logdir_string), expected)
def testDoesNotNormalizeGCSPath(self):
logdir_string = 'gs://foo/./path//..'
expected = {'gs://foo/./path//..': None}
self.assertEqual(server.ParseEventFilesSpec(logdir_string), expected)
class TensorBoardAssetsTest(tf.test.TestCase): class TensorBoardAssetsTest(tf.test.TestCase):

View File

@ -10,7 +10,6 @@
"iron-flex-layout", "iron-flex-layout",
"iron-form-element-behavior", "iron-form-element-behavior",
"iron-icon", "iron-icon",
"iron-icons",
"iron-iconset-svg", "iron-iconset-svg",
"iron-input", "iron-input",
"iron-menu-behavior", "iron-menu-behavior",
@ -40,8 +39,8 @@
"iron-a11y-announcer": "PolymerElements/iron-a11y-announcer#1.0.4", "iron-a11y-announcer": "PolymerElements/iron-a11y-announcer#1.0.4",
"iron-a11y-keys-behavior": "PolymerElements/iron-a11y-keys-behavior#1.1.2", "iron-a11y-keys-behavior": "PolymerElements/iron-a11y-keys-behavior#1.1.2",
"iron-ajax": "PolymerElements/iron-ajax#1.2.0", "iron-ajax": "PolymerElements/iron-ajax#1.2.0",
"iron-autogrow-textarea": "PolymerElements/iron-autogrow-textarea#1.0.11", "iron-autogrow-textarea": "PolymerElements/iron-autogrow-textarea#1.0.12",
"iron-behaviors": "PolymerElements/iron-behaviors#1.0.16", "iron-behaviors": "PolymerElements/iron-behaviors#1.0.17",
"iron-checked-element-behavior": "PolymerElements/iron-checked-element-behavior#1.0.4", "iron-checked-element-behavior": "PolymerElements/iron-checked-element-behavior#1.0.4",
"iron-collapse": "PolymerElements/iron-collapse#1.0.8", "iron-collapse": "PolymerElements/iron-collapse#1.0.8",
"iron-dropdown": "PolymerElements/iron-dropdown#1.4.0", "iron-dropdown": "PolymerElements/iron-dropdown#1.4.0",
@ -49,16 +48,15 @@
"iron-flex-layout": "PolymerElements/iron-flex-layout#1.3.0", "iron-flex-layout": "PolymerElements/iron-flex-layout#1.3.0",
"iron-form-element-behavior": "PolymerElements/iron-form-element-behavior#1.0.6", "iron-form-element-behavior": "PolymerElements/iron-form-element-behavior#1.0.6",
"iron-icon": "PolymerElements/iron-icon#1.0.8", "iron-icon": "PolymerElements/iron-icon#1.0.8",
"iron-icons": "PolymerElements/iron-icons#1.1.3",
"iron-iconset-svg": "PolymerElements/iron-iconset-svg#1.0.9", "iron-iconset-svg": "PolymerElements/iron-iconset-svg#1.0.9",
"iron-input": "PolymerElements/iron-input#1.0.7", "iron-input": "PolymerElements/iron-input#1.0.10",
"iron-list": "PolymerElements/iron-list#1.1.7", "iron-list": "PolymerElements/iron-list#1.1.7",
"iron-menu-behavior": "PolymerElements/iron-menu-behavior#1.1.8", "iron-menu-behavior": "PolymerElements/iron-menu-behavior#1.1.8",
"iron-meta": "PolymerElements/iron-meta#1.1.1", "iron-meta": "PolymerElements/iron-meta#1.1.1",
"iron-overlay-behavior": "PolymerElements/iron-overlay-behavior#1.7.6", "iron-overlay-behavior": "PolymerElements/iron-overlay-behavior#1.7.6",
"iron-range-behavior": "PolymerElements/iron-range-behavior#1.0.4", "iron-range-behavior": "PolymerElements/iron-range-behavior#1.0.4",
"iron-resizable-behavior": "PolymerElements/iron-resizable-behavior#1.0.3", "iron-resizable-behavior": "PolymerElements/iron-resizable-behavior#1.0.3",
"iron-selector": "PolymerElements/iron-selector#1.2.4", "iron-selector": "PolymerElements/iron-selector#1.5.2",
"iron-validatable-behavior": "PolymerElements/iron-validatable-behavior#1.1.1", "iron-validatable-behavior": "PolymerElements/iron-validatable-behavior#1.1.1",
"lodash": "3.8.0", "lodash": "3.8.0",
"neon-animation": "PolymerElements/neon-animation#1.2.2", "neon-animation": "PolymerElements/neon-animation#1.2.2",
@ -67,14 +65,14 @@
"paper-checkbox": "PolymerElements/paper-checkbox#1.1.3", "paper-checkbox": "PolymerElements/paper-checkbox#1.1.3",
"paper-dialog": "PolymerElements/paper-dialog#1.0.4", "paper-dialog": "PolymerElements/paper-dialog#1.0.4",
"paper-dialog-behavior": "PolymerElements/paper-dialog-behavior#1.2.5", "paper-dialog-behavior": "PolymerElements/paper-dialog-behavior#1.2.5",
"paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#1.1.3", "paper-dropdown-menu": "PolymerElements/paper-dropdown-menu#1.3.2",
"paper-header-panel": "PolymerElements/paper-header-panel#1.1.4", "paper-header-panel": "PolymerElements/paper-header-panel#1.1.4",
"paper-icon-button": "PolymerElements/paper-icon-button#1.1.1", "paper-icon-button": "PolymerElements/paper-icon-button#1.1.1",
"paper-input": "PolymerElements/paper-input#1.1.5", "paper-input": "PolymerElements/paper-input#1.1.14",
"paper-item": "PolymerElements/paper-item#1.1.4", "paper-item": "PolymerElements/paper-item#1.1.4",
"paper-material": "PolymerElements/paper-material#1.0.6", "paper-material": "PolymerElements/paper-material#1.0.6",
"paper-menu": "PolymerElements/paper-menu#1.2.2", "paper-menu": "PolymerElements/paper-menu#1.2.2",
"paper-menu-button": "PolymerElements/paper-menu-button#1.2.0", "paper-menu-button": "PolymerElements/paper-menu-button#1.5.0",
"paper-progress": "PolymerElements/paper-progress#1.0.9", "paper-progress": "PolymerElements/paper-progress#1.0.9",
"paper-radio-button": "PolymerElements/paper-radio-button#1.1.2", "paper-radio-button": "PolymerElements/paper-radio-button#1.1.2",
"paper-radio-group": "PolymerElements/paper-radio-group#1.0.9", "paper-radio-group": "PolymerElements/paper-radio-group#1.0.9",
@ -116,8 +114,8 @@
"iron-a11y-announcer": "1.0.4", "iron-a11y-announcer": "1.0.4",
"iron-a11y-keys-behavior": "1.1.2", "iron-a11y-keys-behavior": "1.1.2",
"iron-ajax": "1.2.0", "iron-ajax": "1.2.0",
"iron-autogrow-textarea": "1.0.11", "iron-autogrow-textarea": "1.0.12",
"iron-behaviors": "1.0.16", "iron-behaviors": "1.0.17",
"iron-checked-element-behavior": "1.0.4", "iron-checked-element-behavior": "1.0.4",
"iron-collapse": "1.0.8", "iron-collapse": "1.0.8",
"iron-dropdown": "1.4.0", "iron-dropdown": "1.4.0",
@ -127,14 +125,14 @@
"iron-icon": "1.0.8", "iron-icon": "1.0.8",
"iron-icons": "1.1.3", "iron-icons": "1.1.3",
"iron-iconset-svg": "1.0.9", "iron-iconset-svg": "1.0.9",
"iron-input": "1.0.7", "iron-input": "1.0.10",
"iron-list": "1.1.7", "iron-list": "1.1.7",
"iron-menu-behavior": "1.1.8", "iron-menu-behavior": "1.1.8",
"iron-meta": "1.1.1", "iron-meta": "1.1.1",
"iron-overlay-behavior": "1.7.6", "iron-overlay-behavior": "1.7.6",
"iron-range-behavior": "1.0.4", "iron-range-behavior": "1.0.4",
"iron-resizable-behavior": "1.0.3", "iron-resizable-behavior": "1.0.3",
"iron-selector": "1.2.4", "iron-selector": "1.5.2",
"iron-validatable-behavior": "1.1.1", "iron-validatable-behavior": "1.1.1",
"lodash": "3.8.0", "lodash": "3.8.0",
"neon-animation": "1.2.2", "neon-animation": "1.2.2",
@ -143,14 +141,14 @@
"paper-checkbox": "1.1.3", "paper-checkbox": "1.1.3",
"paper-dialog": "1.0.4", "paper-dialog": "1.0.4",
"paper-dialog-behavior": "1.2.5", "paper-dialog-behavior": "1.2.5",
"paper-dropdown-menu": "1.1.3", "paper-dropdown-menu": "1.3.2",
"paper-header-panel": "1.1.4", "paper-header-panel": "1.1.4",
"paper-icon-button": "1.1.1", "paper-icon-button": "1.1.1",
"paper-input": "1.1.5", "paper-input": "1.1.14",
"paper-item": "1.1.4", "paper-item": "1.1.4",
"paper-material": "1.0.6", "paper-material": "1.0.6",
"paper-menu": "1.2.2", "paper-menu": "1.2.2",
"paper-menu-button": "1.2.0", "paper-menu-button": "1.5.0",
"paper-progress": "1.0.9", "paper-progress": "1.0.9",
"paper-radio-button": "1.1.2", "paper-radio-button": "1.1.2",
"paper-radio-group": "1.0.9", "paper-radio-group": "1.0.9",

View File

@ -22,7 +22,6 @@ filegroup(
"@iron_flex_layout//:iron_flex_layout", "@iron_flex_layout//:iron_flex_layout",
"@iron_form_element_behavior//:iron_form_element_behavior", "@iron_form_element_behavior//:iron_form_element_behavior",
"@iron_icon//:iron_icon", "@iron_icon//:iron_icon",
"@iron_icons//:iron_icons",
"@iron_iconset_svg//:iron_iconset_svg", "@iron_iconset_svg//:iron_iconset_svg",
"@iron_input//:iron_input", "@iron_input//:iron_input",
"@iron_list//:iron_list", "@iron_list//:iron_list",

View File

@ -7,7 +7,10 @@ exports_files(["LICENSE"])
filegroup( filegroup(
name = "all_files", name = "all_files",
srcs = glob( srcs = glob(
["tf-*/**/*", "vz-*/**/*"], [
"tf-*/**/*",
"vz-*/**/*",
],
exclude = [ exclude = [
"**/tf_model_zoo/*", "**/tf_model_zoo/*",
"**/METADATA", "**/METADATA",

View File

@ -182,11 +182,16 @@ module TF.Backend {
let url = this.router.histograms(tag, run); let url = this.router.histograms(tag, run);
p = this.requestManager.request(url); p = this.requestManager.request(url);
return p.then(map(detupler(createHistogram))).then(function(histos) { return p.then(map(detupler(createHistogram))).then(function(histos) {
// Get the minimum and maximum values across all histograms so that the
// visualization is aligned for all timesteps.
let min = d3.min(histos, d => d.min);
let max = d3.max(histos, d => d.max);
return histos.map(function(histo, i) { return histos.map(function(histo, i) {
return { return {
wall_time: histo.wall_time, wall_time: histo.wall_time,
step: histo.step, step: histo.step,
bins: convertBins(histo) bins: convertBins(histo, min, max)
}; };
}); });
}); });
@ -254,11 +259,65 @@ module TF.Backend {
} }
/** Given a RunToTag, return sorted array of all runs */ /** Given a RunToTag, return sorted array of all runs */
export function getRuns(r: RunToTag): string[] { return _.keys(r).sort(); } export function getRuns(r: RunToTag): string[] {
return _.keys(r).sort(compareTagNames);
}
/** Given a RunToTag, return array of all tags (sorted + dedup'd) */ /** Given a RunToTag, return array of all tags (sorted + dedup'd) */
export function getTags(r: RunToTag): string[] { export function getTags(r: RunToTag): string[] {
return _.union.apply(null, _.values(r)).sort(); return _.union.apply(null, _.values(r)).sort(compareTagNames);
}
/** Compares tag names asciinumerically broken into components. */
export function compareTagNames(a, b: string): number {
let ai = 0;
let bi = 0;
while (true) {
if (ai === a.length) return bi === b.length ? 0 : -1;
if (bi === b.length) return 1;
if (isDigit(a[ai]) && isDigit(b[bi])) {
let ais = ai;
let bis = bi;
ai = consumeNumber(a, ai + 1);
bi = consumeNumber(b, bi + 1);
let an = parseFloat(a.slice(ais, ai));
let bn = parseFloat(b.slice(bis, bi));
if (an < bn) return -1;
if (an > bn) return 1;
continue;
}
if (isBreak(a[ai])) {
if (!isBreak(b[bi])) return -1;
} else if (isBreak(b[bi])) {
return 1;
} else if (a[ai] < b[bi]) {
return -1;
} else if (a[ai] > b[bi]) {
return 1;
}
ai++;
bi++;
}
}
function consumeNumber(s: string, i: number): number {
let decimal = false;
for (; i < s.length; i++) {
if (isDigit(s[i])) continue;
if (!decimal && s[i] === '.') {
decimal = true;
continue;
}
break;
}
return i;
}
function isDigit(c: string): boolean { return '0' <= c && c <= '9'; }
function isBreak(c: string): boolean {
// TODO(jart): Remove underscore when people stop using it like a slash.
return c === '/' || c === '_' || isDigit(c);
} }
/** /**
@ -313,34 +372,59 @@ module TF.Backend {
* Takes histogram data as stored by tensorboard backend and converts it to * Takes histogram data as stored by tensorboard backend and converts it to
* the standard d3 histogram data format to make it more compatible and easier * the standard d3 histogram data format to make it more compatible and easier
* to visualize. When visualizing histograms, having the left edge and width * to visualize. When visualizing histograms, having the left edge and width
* makes things quite a bit easier. * makes things quite a bit easier. The bins are also converted to have an
* uniform width, what makes the visualization easier to understand.
* *
* @param histogram A histogram from tensorboard backend. * @param histogram A histogram from tensorboard backend.
* @param min The leftmost edge. The binning will start on it.
* @param max The rightmost edge. The binning will end on it.
* @param numBins The number of bins of the converted data. The default of 30
* is a sensible default, using more starts to get artifacts because the event
* data is stored in buckets, and you start being able to see the aliased
* borders between each bucket.
* @return A histogram bin. Each bin has an x (left edge), a dx (width), * @return A histogram bin. Each bin has an x (left edge), a dx (width),
* and a y (count). * and a y (count).
* *
* If given rightedges are inclusive, then these left edges (x) are exclusive. * If given rightedges are inclusive, then these left edges (x) are exclusive.
*/ */
export function convertBins(histogram: Histogram) { export function convertBins(
histogram: Histogram, min: number, max: number, numBins = 30) {
if (histogram.bucketRightEdges.length !== histogram.bucketCounts.length) { if (histogram.bucketRightEdges.length !== histogram.bucketCounts.length) {
throw(new Error('Edges and counts are of different lengths.')); throw(new Error('Edges and counts are of different lengths.'));
} }
var previousRightEdge = histogram.min; let binWidth = (max - min) / numBins;
return histogram.bucketRightEdges.map(function( let bucketLeft = min; // Use the min as the starting point for the bins.
rightEdge: number, i: number) { let bucketPos = 0;
return d3.range(min, max, binWidth).map(function(binLeft) {
let binRight = binLeft + binWidth;
// Use the previous bin's rightEdge as the new leftEdge // Take the count of each existing bucket, multiply it by the proportion
var left = previousRightEdge; // of overlap with the new bin, then sum and store as the count for the
// new bin. If no overlap, will add to zero, if 100% overlap, will include
// the full count into new bin.
let binY = 0;
while (bucketPos < histogram.bucketRightEdges.length) {
// Clip the right edge because right-most edge can be infinite-sized.
let bucketRight = Math.min(max, histogram.bucketRightEdges[bucketPos]);
// We need to clip the rightEdge because right-most edge can be let intersect =
// infinite-sized Math.min(bucketRight, binRight) - Math.max(bucketLeft, binLeft);
var right = Math.min(histogram.max, rightEdge); let count = (intersect / (bucketRight - bucketLeft)) *
histogram.bucketCounts[bucketPos];
// Store rightEdgeValue for next iteration binY += intersect > 0 ? count : 0;
previousRightEdge = rightEdge;
return {x: left, dx: right - left, y: histogram.bucketCounts[i]}; // If bucketRight is bigger than binRight, than this bin is finished and
// there is data for the next bin, so don't increment bucketPos.
if (bucketRight > binRight) {
break;
}
bucketLeft = Math.max(min, bucketRight);
bucketPos++;
};
return {x: binLeft, dx: binWidth, y: binY};
}); });
} }

View File

@ -191,13 +191,16 @@ module TF.Backend {
it('Throws and error if the inputs are of different lengths', function() { it('Throws and error if the inputs are of different lengths', function() {
assert.throws(function() { assert.throws(function() {
convertBins( convertBins(
{bucketRightEdges: [0], bucketCounts: [1, 2], min: 1, max: 2}); {bucketRightEdges: [0], bucketCounts: [1, 2], min: 1, max: 2}, 1, 2,
2);
}, 'Edges and counts are of different lengths.'); }, 'Edges and counts are of different lengths.');
}); });
it('Handles data with no bins', function() { it('Handles data with no bins', function() {
assert.deepEqual( assert.deepEqual(
convertBins({bucketRightEdges: [], bucketCounts: [], min: 0, max: 0}), convertBins(
{bucketRightEdges: [], bucketCounts: [], min: 0, max: 0}, 0, 0,
0),
[]); []);
}); });
@ -205,12 +208,14 @@ module TF.Backend {
let counts = [1]; let counts = [1];
let rightEdges = [1.21e-12]; let rightEdges = [1.21e-12];
let histogram = [{x: 1.1e-12, dx: 1.21e-12 - 1.1e-12, y: 1}]; let histogram = [{x: 1.1e-12, dx: 1.21e-12 - 1.1e-12, y: 1}];
let newHistogram = convertBins({ let newHistogram = convertBins(
bucketRightEdges: rightEdges, {
bucketCounts: counts, bucketRightEdges: rightEdges,
min: 1.1e-12, bucketCounts: counts,
max: 1.21e-12 min: 1.1e-12,
}); max: 1.21e-12
},
1.1e-12, 1.21e-12, 1);
assertHistogramEquality(newHistogram, histogram); assertHistogramEquality(newHistogram, histogram);
}); });
@ -218,15 +223,17 @@ module TF.Backend {
let counts = [1, 2]; let counts = [1, 2];
let rightEdges = [1.1e-12, 1.21e-12]; let rightEdges = [1.1e-12, 1.21e-12];
let histogram = [ let histogram = [
{x: 1.0e-12, dx: 1.1e-12 - 1.0e-12, y: 1}, {x: 1.0e-12, dx: 1.05e-13, y: 1.09090909090909},
{x: 1.1e-12, dx: 1.21e-12 - 1.1e-12, y: 2} {x: 1.105e-12, dx: 1.05e-13, y: 1.9090909090909}
]; ];
let newHistogram = convertBins({ let newHistogram = convertBins(
bucketRightEdges: rightEdges, {
bucketCounts: counts, bucketRightEdges: rightEdges,
min: 1.0e-12, bucketCounts: counts,
max: 1.21e-12 min: 1.0e-12,
}); max: 1.21e-12
},
1.0e-12, 1.21e-12, 2);
assertHistogramEquality(newHistogram, histogram); assertHistogramEquality(newHistogram, histogram);
}); });
@ -236,15 +243,17 @@ module TF.Backend {
let counts = [1, 2]; let counts = [1, 2];
let rightEdges = [-1.0e-12, 1.0e-12]; let rightEdges = [-1.0e-12, 1.0e-12];
let histogram = [ let histogram = [
{x: -1.1e-12, dx: 1.1e-12 - 1.0e-12, y: 1}, {x: -1.1e-12, dx: 1.05e-12, y: 1.95},
{x: -1.0e-12, dx: 2.0e-12, y: 2} {x: -0.5e-13, dx: 1.05e-12, y: 1.05}
]; ];
let newHistogram = convertBins({ let newHistogram = convertBins(
bucketRightEdges: rightEdges, {
bucketCounts: counts, bucketRightEdges: rightEdges,
min: -1.1e-12, bucketCounts: counts,
max: 1.0e-12 min: -1.1e-12,
}); max: 1.0e-12
},
-1.1e-12, 1.0e-12, 2);
assertHistogramEquality(newHistogram, histogram); assertHistogramEquality(newHistogram, histogram);
}); });
@ -253,16 +262,71 @@ module TF.Backend {
let counts = [1, 2, 3]; let counts = [1, 2, 3];
let rightEdges = [0, 1.0e-12, 1.0e14]; let rightEdges = [0, 1.0e-12, 1.0e14];
let histogram = [ let histogram = [
{x: -1.0e-12, dx: 1.0e-12, y: 1}, {x: 0, dx: 1.0e-12, y: 2}, {x: -1.0e-12, dx: 0.7e-12, y: 0.7},
{x: 1.0e-12, dx: 1.1e-12 - 1.0e-12, y: 3} {x: -0.3e-12, dx: 0.7e-12, y: 1.1},
{x: 0.4e-12, dx: 0.7e-12, y: 4.2}
]; ];
let newHistogram = convertBins({ let newHistogram = convertBins(
bucketRightEdges: rightEdges, {
bucketCounts: counts, bucketRightEdges: rightEdges,
min: -1.0e-12, bucketCounts: counts,
max: 1.1e-12 min: -1.0e-12,
}); max: 1.1e-12
},
-1.0e-12, 1.1e-12, 3);
assertHistogramEquality(newHistogram, histogram); assertHistogramEquality(newHistogram, histogram);
}); });
}); });
describe('sortTagNames', () => {
let sortTagNames = (a) => a.sort(compareTagNames);
it('is asciibetical', () => {
assert.deepEqual(sortTagNames(['a', 'b']), ['a', 'b']);
assert.deepEqual(sortTagNames(['a', 'B']), ['B', 'a']);
});
it('sorts integer portions', () => {
assert.deepEqual(['03', '1'].sort(), ['03', '1']);
assert.deepEqual(sortTagNames(['03', '1']), ['1', '03']);
assert.deepEqual(sortTagNames(['a03', 'a1']), ['a1', 'a03']);
assert.deepEqual(sortTagNames(['a03', 'b1']), ['a03', 'b1']);
assert.deepEqual(sortTagNames(['x0a03', 'x0a1']), ['x0a1', 'x0a03']);
assert.deepEqual(sortTagNames(['a/b/03', 'a/b/1']), ['a/b/1', 'a/b/03']);
});
it('sorts floating point portions', () => {
assert.deepEqual(sortTagNames(['a0.1', 'a0.01']), ['a0.01', 'a0.1']);
});
it('is componentized by slash', () => {
assert.deepEqual(['a+/a', 'a/a', 'ab/a'].sort(), ['a+/a', 'a/a', 'ab/a']);
assert.deepEqual(
sortTagNames(['a+/a', 'a/a', 'ab/a']), ['a/a', 'a+/a', 'ab/a']);
});
it('is componentized by underscore', () => {
assert.deepEqual(
sortTagNames(['a+_a', 'a_a', 'ab_a']), ['a_a', 'a+_a', 'ab_a']);
assert.deepEqual(
sortTagNames(['a+/a', 'a_a', 'ab_a']), ['a_a', 'a+/a', 'ab_a']);
});
it('is componentized by number boundaries', () => {
assert.deepEqual(
sortTagNames(['a+0a', 'a0a', 'ab0a']), ['a0a', 'a+0a', 'ab0a']);
});
it('empty comes first', () => {
assert.deepEqual(
sortTagNames(['a', '//', '/', '']), ['', '/', '//', 'a']);
});
it('decimal parsed correctly', () => {
assert.deepEqual(sortTagNames(['0.2', '0.03']), ['0.03', '0.2']);
assert.deepEqual(sortTagNames(['0..2', '0..03']), ['0..2', '0..03']);
assert.deepEqual(sortTagNames(['.2', '.03']), ['.2', '.03']);
});
});
} }

View File

@ -19,6 +19,7 @@ limitations under the License.
<script src="../../webcomponentsjs/webcomponents-lite.min.js"></script> <script src="../../webcomponentsjs/webcomponents-lite.min.js"></script>
<script src="../../web-component-tester/browser.js"></script> <script src="../../web-component-tester/browser.js"></script>
<link rel="import" href="../../polymer/polymer.html"> <link rel="import" href="../../polymer/polymer.html">
<link rel="import" href="../../tf-imports/d3.html">
</head> </head>
<body> <body>
<test-fixture id="testElementFixture"> <test-fixture id="testElementFixture">

View File

@ -1,5 +1,6 @@
<link rel="import" href="../polymer/polymer.html"> <link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-imports/lodash.html"> <link rel="import" href="../tf-imports/lodash.html">
<link rel="import" href="../tf-imports/d3.html">
<script src="requestManager.js"></script> <script src="requestManager.js"></script>
<script src="urlPathHelpers.js"></script> <script src="urlPathHelpers.js"></script>

View File

@ -16,19 +16,23 @@
position: relative; position: relative;
} }
.card .card-title { .card .card-title, .card .card-subtitle {
flex-grow: 0; flex-grow: 0;
flex-shrink: 0; flex-shrink: 0;
margin-bottom: 10px;
font-size: 14px; font-size: 14px;
text-overflow: ellipsis; text-overflow: ellipsis;
overflow: hidden; overflow: hidden;
} }
.card .card-subtitle {
font-size: 12px;
}
.card .card-content { .card .card-content {
flex-grow: 1; flex-grow: 1;
flex-shrink: 1; flex-shrink: 1;
display: flex; display: flex;
margin-top: 10px;
} }
.card .card-bottom-row { .card .card-bottom-row {
position: absolute; position: absolute;

View File

@ -36,10 +36,11 @@
selectedRuns: Array, selectedRuns: Array,
xType: String, xType: String,
dataProvider: Function, dataProvider: Function,
_initialized: Boolean, _attached: Boolean,
_makeChartAsyncCallbackId: { type: Number, value: null }
}, },
observers: [ observers: [
"_makeChart(tag, dataProvider, xType, colorScale, _initialized)", "_makeChart(tag, dataProvider, xType, colorScale, _attached)",
"_changeRuns(_chart, selectedRuns.*)" "_changeRuns(_chart, selectedRuns.*)"
], ],
_changeRuns: function(chart) { _changeRuns: function(chart) {
@ -55,23 +56,26 @@
reload: function() { reload: function() {
this._chart.reload(); this._chart.reload();
}, },
_makeChart: function(tag, dataProvider, xType, colorScale, _initialized) { _makeChart: function(tag, dataProvider, xType, colorScale, _attached) {
if (!_initialized) { if (this._makeChartAsyncCallbackId === null) {
return; this.cancelAsync(this._makeChartAsyncCallbackId);
} }
if (this._chart) this._chart.destroy();
var chart = new TF.DistributionChart(tag, dataProvider, xType, colorScale); this._makeChartAsyncCallbackId = this.async(function() {
var svg = d3.select(this.$.chartsvg); this._makeChartAsyncCallbackId = null;
this.async(function() { if (!_attached) return;
if (this._chart) this._chart.destroy();
var chart = new TF.DistributionChart(tag, dataProvider, xType, colorScale);
var svg = d3.select(this.$.chartsvg);
chart.renderTo(svg); chart.renderTo(svg);
this._chart = chart; this._chart = chart;
}, 350); }, 350);
}, },
attached: function() { attached: function() {
this._initialized = true; this._attached = true;
}, },
detached: function() { detached: function() {
this._initialized = false; this._attached = false;
} }
}); });
</script> </script>

View File

@ -1,6 +1,6 @@
<link rel="import" href="../polymer/polymer.html"> <link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-event-dashboard/tf-run-selector.html"> <link rel="import" href="../tf-event-dashboard/tf-run-selector.html">
<link rel="import" href="../tf-event-dashboard/tf-x-type-selector.html"> <link rel="import" href="../tf-option-selector/tf-option-selector.html">
<link rel="import" href="../tf-color-scale/tf-color-scale.html"> <link rel="import" href="../tf-color-scale/tf-color-scale.html">
<link rel="import" href="../tf-dashboard-common/tf-dashboard.html"> <link rel="import" href="../tf-dashboard-common/tf-dashboard.html">
<link rel="import" href="../tf-categorizer/tf-categorizer.html"> <link rel="import" href="../tf-categorizer/tf-categorizer.html">
@ -47,10 +47,15 @@ tf-collapsable-panes.
></tf-categorizer> ></tf-categorizer>
</div> </div>
<div class="sidebar-section"> <div class="sidebar-section">
<tf-x-type-selector <tf-option-selector
id="xTypeSelector" id="xTypeSelector"
out-x-type="{{xType}}" name="Horizontal Axis"
></tf-x-type-selector> selected-id="{{_xType}}"
>
<paper-button id="step">step</paper-button>
<paper-button id="relative">relative</paper-button>
<paper-button id="wall_time">wall</paper-button>
</tf-option-selector>
</div> </div>
<div class="sidebar-section"> <div class="sidebar-section">
<tf-run-selector <tf-run-selector
@ -80,7 +85,7 @@ tf-collapsable-panes.
tag="[[tag]]" tag="[[tag]]"
id="chart" id="chart"
selected-runs="[[_array(run)]]" selected-runs="[[_array(run)]]"
x-type="[[xType]]" x-type="[[_xType]]"
data-provider="[[dataProvider]]" data-provider="[[dataProvider]]"
color-scale="[[colorScale]]" color-scale="[[colorScale]]"
on-keyup="toggleSelected" on-keyup="toggleSelected"
@ -117,6 +122,10 @@ tf-collapsable-panes.
type: Array, type: Array,
computed: "_getVisibleTags(selectedRuns.*, run2tag.*)" computed: "_getVisibleTags(selectedRuns.*, run2tag.*)"
}, },
_xType: {
type: String,
value: "step"
},
dataType: {value: "compressedHistogram"}, dataType: {value: "compressedHistogram"},
}, },
_exists: function(run, tag) { _exists: function(run, tag) {

View File

@ -1,7 +1,7 @@
<link rel="import" href="../polymer/polymer.html"> <link rel="import" href="../polymer/polymer.html">
<link rel="import" href="tf-run-selector.html"> <link rel="import" href="tf-run-selector.html">
<link rel="import" href="tf-smoothing-input.html"> <link rel="import" href="tf-smoothing-input.html">
<link rel="import" href="tf-x-type-selector.html"> <link rel="import" href="../tf-option-selector/tf-option-selector.html">
<link rel="import" href="../tf-color-scale/tf-color-scale.html"> <link rel="import" href="../tf-color-scale/tf-color-scale.html">
<link rel="import" href="../tf-categorizer/tf-categorizer.html"> <link rel="import" href="../tf-categorizer/tf-categorizer.html">
<link rel="import" href="../tf-chart-scaffold/tf-chart-scaffold.html"> <link rel="import" href="../tf-chart-scaffold/tf-chart-scaffold.html">
@ -59,10 +59,15 @@ The #center div contains tf-line-charts embedded inside tf-collapsable-panes.
></tf-smoothing-input> ></tf-smoothing-input>
</div> </div>
<div class="sidebar-section"> <div class="sidebar-section">
<tf-x-type-selector <tf-option-selector
id="xTypeSelector" id="xTypeSelector"
out-x-type="{{xType}}" name="Horizontal Axis"
></tf-x-type-selector> selected-id="{{_xType}}"
>
<paper-button id="step">step</paper-button>
<paper-button id="relative">relative</paper-button>
<paper-button id="wall_time">wall</paper-button>
</tf-option-selector>
</div> </div>
<div class="sidebar-section"> <div class="sidebar-section">
<tf-run-selector <tf-run-selector
@ -92,7 +97,7 @@ The #center div contains tf-line-charts embedded inside tf-collapsable-panes.
> >
<vz-line-chart <vz-line-chart
id="chart" id="chart"
x-type="[[xType]]" x-type="[[_xType]]"
color-scale="[[colorScale]]" color-scale="[[colorScale]]"
smoothing-decay="[[_smoothingDecay]]" smoothing-decay="[[_smoothingDecay]]"
smoothing-enabled="[[_smoothingEnabled]]" smoothing-enabled="[[_smoothingEnabled]]"
@ -160,6 +165,10 @@ The #center div contains tf-line-charts embedded inside tf-collapsable-panes.
type: Object, type: Object,
notify: true, notify: true,
}, },
_xType: {
type: String,
value: "step"
}
}, },
attached: function() { attached: function() {
this.async(function() { this.async(function() {

View File

@ -16,7 +16,8 @@ limitations under the License.
/* tslint:disable:no-namespace */ /* tslint:disable:no-namespace */
module TF.Globals { module TF.Globals {
// The names of TensorBoard tabs. // The names of TensorBoard tabs.
export var TABS = ['events', 'images', 'audio', 'graphs', 'distributions']; export var TABS =
['events', 'images', 'audio', 'graphs', 'distributions', 'histograms'];
// If true, TensorBoard stores its hash in the URI state. // If true, TensorBoard stores its hash in the URI state.
// If false, tab switching in TensorBoard will not update location hash, // If false, tab switching in TensorBoard will not update location hash,

View File

@ -0,0 +1,90 @@
node {
name: "life"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 2
}
}
}
}
node {
name: "universe"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 40
}
}
}
}
node {
name: "everything"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "Add"
op: "Add"
input: "life"
input: "universe"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "answer"
op: "Add"
input: "Add"
input: "everything"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
versions {
producer: 10
}

View File

@ -0,0 +1,28 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="import" href="../tf-graph-app.html">
<link rel="import" href="../../iron-demo-helpers/demo-snippet.html">
<style>
body {
margin: 0;
}
</style>
</head>
<body>
<h3>Answer to the Ultimate Question of Life, the Universe, and Everything</h3>
<demo-snippet>
<template>
<tf-graph-app id="tfgraph"></tf-graph-app>
<script>
let g = document.querySelector("#tfgraph");
fetch("graph.pbtxt").then(r => r.text()).then(pbtxt => {
g.pbtxt = pbtxt;
});
</script>
</template>
</demo-snippet>
</body>
</html>

View File

@ -0,0 +1,14 @@
<!doctype html>
<html>
<head>
<title>vz-vega</title>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="../webcomponentsjs/webcomponents-lite.js"></script>
<link rel="import" href="../iron-component-page/iron-component-page.html">
</head>
<body>
<iron-component-page src="tf-graph-app.html"></iron-component-page>
</body>
</html>

View File

@ -2,12 +2,22 @@
<link rel="import" href="../tf-graph-board/tf-graph-board.html"> <link rel="import" href="../tf-graph-board/tf-graph-board.html">
<link rel="import" href="../tf-graph-loader/tf-graph-loader.html"> <link rel="import" href="../tf-graph-loader/tf-graph-loader.html">
<link rel="import" href="../tf-graph/tf-graph-controls.html"> <link rel="import" href="../tf-graph/tf-graph-controls.html">
<!-- Stand alone element of tf-graph for embedding. <!--
Stand alone element of tf-graph for embedding.
Example The pbtxt format is the stringified version of the graphdef.
<tf-graph-app pbtxt="[[pbtxt]]"></tf-graph-app> <tf-graph-app pbtxt="[[pbtxt]]"></tf-graph-app>
import tensorflow as tf
life = tf.constant(2, name='life')
universe = tf.constant(40, name='universe')
everything = tf.constant(0, name='everything')
lifeuniverse = tf.add(life, universe)
answer = tf.add(lifeuniverse, everything, name='answer')
open("graph.pbtxt", "w").write(str(tf.get_default_graph().as_graph_def()))
@demo
--> -->
<dom-module id="tf-graph-app"> <dom-module id="tf-graph-app">

View File

@ -442,7 +442,7 @@
_getHasDisplayableNodeStats: function(stats) { _getHasDisplayableNodeStats: function(stats) {
return tf.graph.util.hasDisplayableNodeStats(stats); return tf.graph.util.hasDisplayableNodeStats(stats);
}, },
_getNodeStatsFormattedBytes(stats) { _getNodeStatsFormattedBytes: function(stats) {
if (!stats || !stats.totalBytes) { if (!stats || !stats.totalBytes) {
return; return;
} }
@ -450,7 +450,7 @@
return tf.graph.util.convertUnitsToHumanReadable( return tf.graph.util.convertUnitsToHumanReadable(
stats.totalBytes, tf.graph.util.MEMORY_UNITS); stats.totalBytes, tf.graph.util.MEMORY_UNITS);
}, },
_getNodeStatsFormattedComputeTime(stats) { _getNodeStatsFormattedComputeTime: function(stats) {
if (!stats || !stats.totalMicros) { if (!stats || !stats.totalMicros) {
return; return;
} }
@ -458,7 +458,7 @@
return tf.graph.util.convertUnitsToHumanReadable( return tf.graph.util.convertUnitsToHumanReadable(
stats.totalMicros, tf.graph.util.TIME_UNITS); stats.totalMicros, tf.graph.util.TIME_UNITS);
}, },
_getNodeStatsFormattedOutputSizes(stats) { _getNodeStatsFormattedOutputSizes: function(stats) {
if (!stats || !stats.outputSize || !stats.outputSize.length) { if (!stats || !stats.outputSize || !stats.outputSize.length) {
return; return;
} }

View File

@ -0,0 +1,202 @@
<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-chart-scaffold/tf-chart-scaffold.html">
<link rel="import" href="../tf-event-dashboard/tf-run-selector.html">
<link rel="import" href="../tf-color-scale/tf-color-scale.html">
<link rel="import" href="../tf-dashboard-common/tf-dashboard.html">
<link rel="import" href="../tf-categorizer/tf-categorizer.html">
<link rel="import" href="../tf-option-selector/tf-option-selector.html">
<link rel="import" href="../tf-collapsable-pane/tf-collapsable-pane.html">
<link rel="import" href="../vz-histogram-timeseries/vz-histogram-timeseries.html">
<link rel="import" href="../iron-collapse/iron-collapse.html">
<link rel="import" href="../paper-icon-button/paper-icon-button.html">
<link rel="import" href="../tf-imports/lodash.html">
<link rel="import" href="../tf-backend/tf-backend.html">
<!--
tf-histogram-dashboard is a complete frontend that loads runs from a backend,
and creates chart panes that display data for those runs.
It provides a categorizer, run selector, and x type selector, by which the user
can customize how data is organized and displayed.
Each chart has a button that can toggle whether it is "selected"; selectedRuns
charts are larger.
Organizationally, the #plumbing div contains components that have no concrete
manifestation and just effect data bindings or data loading. The #sidebar contains
shared controls like the tf-categorizer, tf-run-selector, and tf-x-type-selector.
The #center div contains vz-histogram-timeseries embedded inside
tf-collapsable-panes.
-->
<dom-module id="tf-histogram-dashboard">
<template>
<div id="plumbing">
<tf-color-scale
id="colorScale"
runs="[[runs]]"
out-color-scale="{{colorScale}}"
></tf-color-scale>
</div>
<tf-dashboard-layout>
<div class="sidebar">
<div class="sidebar-section">
<tf-categorizer
id="categorizer"
tags="[[_visibleTags]]"
categories="{{categories}}"
></tf-categorizer>
</div>
<div class="sidebar-section">
<tf-option-selector
id="histogramModeSelector"
name="Histogram Mode"
selected-id="{{_histogramMode}}"
>
<paper-button id="overlay">overlay</paper-button>
<paper-button id="offset">offset</paper-button>
</tf-option-selector>
</div>
<div class="sidebar-section">
<tf-option-selector
id="timePropertySelector"
name="Offset Time Axis"
selected-id="{{_timeProperty}}"
>
<paper-button id="step">step</paper-button>
<paper-button id="relative">relative</paper-button>
<paper-button id="wall_time">wall</paper-button>
</tf-option-selector>
</div>
<div class="sidebar-section">
<tf-run-selector
id="runSelector"
runs="[[runs]]"
color-scale="[[colorScale]]"
out-selected="{{selectedRuns}}"
></tf-run-selector>
</div>
</div>
<div class="center">
<tf-no-data-warning
data-type="histogram"
show-warning="[[dataNotFound]]"
></tf-no-data-warning>
<template is="dom-repeat" items="[[categories]]">
<tf-collapsable-pane name="[[item.name]]" count="[[_count(item.tags, selectedRuns.*, runToCompressedHistograms.*)]]">
<div class="layout horizontal wrap">
<template is="dom-repeat" items="[[item.tags]]" as="tag">
<template is="dom-repeat" items="[[selectedRuns]]" as="run">
<template is="dom-if" if="[[_exists(run, tag, run2tag.*)]]">
<div class="card">
<span class="card-title">[[tag]]</span>
<span class="card-subtitle">[[run]]</span>
<div class="card-content">
<tf-chart-scaffold
tag="[[tag]]"
visible-series="[[_array(run)]]"
data-provider="[[dataProvider]]"
>
<vz-histogram-timeseries
id="chart"
time-property="[[_timeProperty]]"
mode="[[_histogramMode]]"
color-scale="[[_colorScaleFunction]]"
on-keyup="toggleSelected"
tabindex="2"
></vz-histogram-timeseries>
</tf-chart-scaffold>
<paper-icon-button
class="expand-button"
icon="fullscreen"
on-tap="toggleSelected"
></paper-icon-button>
</div>
</div>
</template>
</template>
</template>
</div>
</tf-collapsable-pane>
</template>
</div>
</tf-dashboard-layout>
<style include="dashboard-style"></style>
</template>
<script>
Polymer({
is: "tf-histogram-dashboard",
behaviors: [
TF.Dashboard.ReloadBehavior("tf-chart-scaffold"),
TF.Backend.Behavior,
],
properties: {
_histogramMode: {
type: String,
value: "offset"
},
_timeProperty: {
type: String,
value: "step"
},
_visibleTags: {
type: Array,
computed: "_getVisibleTags(selectedRuns.*, run2tag.*)"
},
_colorScaleFunction: {
type: Function,
computed: "_getColorScaleFunction(colorScale)"
},
colorScale: Object,
dataType: {value: "histogram"},
},
_exists: function(run, tag) {
return this.run2tag[run].indexOf(tag) !== -1;
},
attached: function() {
this.async(function() {
this.fire("rendered");
});
},
_array: function(x) {
return [x];
},
_count: function(tags) {
var targetTags = {};
tags.forEach(function(t) {
targetTags[t] = true;
});
var count = 0;
var _this = this;
this.selectedRuns.forEach(function(r) {
_this.run2tag[r].forEach(function(t) {
if (targetTags[t]) {
count++;
}
});
});
return count;
},
_getVisibleTags: function() {
var keys = this.selectedRuns;
var dict = this.run2tag;
return _.union.apply(null, keys.map(function(k) {return dict[k]}));
},
_getColorScaleFunction: function() {
return this.colorScale.scale.bind(this.colorScale);
},
toggleSelected: function(e) {
var currentTarget = Polymer.dom(e.currentTarget);
var parentDiv = currentTarget.parentNode.parentNode;
parentDiv.classList.toggle("selected");
var chartScaffold = currentTarget.previousElementSibling;
if (chartScaffold) {
chartScaffold.chart().redraw();
}
},
});
</script>
</dom-module>

View File

@ -0,0 +1,77 @@
<link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-dashboard-common/tensorboard-color.html">
<!--
tf-option-selector is a simple component that has buttons as content and
provides a "selectedId" property that is one of the IDs of the buttons inside it.
-->
<dom-module id="tf-option-selector">
<template>
<div id="wrap">
<h3>[[name]]</h3>
<div class="content-wrapper"><content></content></div>
</div>
<style>
.content-wrapper ::content > * {
width: 30%;
font-size: 13px;
background: none;
margin-top: 10px;
color: var(--tb-ui-dark-accent);
}
.content-wrapper ::content :first-of-type {
margin-left: 0;
}
.content-wrapper ::content .selected {
background-color: var(--tb-ui-dark-accent);
color: white!important;
}
h3 {
color: var(--paper-grey-800);
margin: 0;
font-weight: normal;
font-size: 14px;
margin-bottom: 5px;
display: block;
pointer-events: none;
}
</style>
</template>
<script>
Polymer({
is: "tf-option-selector",
properties: {
name: String,
selectedId: {
type: String,
notify: true,
observer: '_selectedIdChanged'
}
},
attached: function() {
this.async(function() {
this.getEffectiveChildren().forEach(function(node) {
this.listen(node, 'tap', '_selectTarget');
}.bind(this));
});
},
_selectTarget: function(e) {
this.selectedId = e.currentTarget.id;
},
_selectedIdChanged: function() {
var selected = this.queryEffectiveChildren('#' + this.selectedId);
if (!selected) {
return;
}
this.getEffectiveChildren().forEach(function(node) {
node.classList.remove("selected");
});
selected.classList.add("selected");
}
});
</script>
</dom-module>

View File

@ -8,6 +8,7 @@
<link rel="import" href="../tf-globals/tf-globals.html"> <link rel="import" href="../tf-globals/tf-globals.html">
<link rel="import" href="../tf-event-dashboard/tf-event-dashboard.html"> <link rel="import" href="../tf-event-dashboard/tf-event-dashboard.html">
<link rel="import" href="../tf-distribution-dashboard/tf-distribution-dashboard.html"> <link rel="import" href="../tf-distribution-dashboard/tf-distribution-dashboard.html">
<link rel="import" href="../tf-histogram-dashboard/tf-histogram-dashboard.html">
<link rel="import" href="../tf-image-dashboard/tf-image-dashboard.html"> <link rel="import" href="../tf-image-dashboard/tf-image-dashboard.html">
<link rel="import" href="../tf-audio-dashboard/tf-audio-dashboard.html"> <link rel="import" href="../tf-audio-dashboard/tf-audio-dashboard.html">
<link rel="import" href="../tf-graph-dashboard/tf-graph-dashboard.html"> <link rel="import" href="../tf-graph-dashboard/tf-graph-dashboard.html">
@ -96,6 +97,13 @@ allows the user to toggle between various dashboards.
backend="[[_backend]]" backend="[[_backend]]"
></tf-distribution-dashboard> ></tf-distribution-dashboard>
</template> </template>
<template is="dom-if" if="[[_modeIsHistograms(mode)]]">
<tf-histogram-dashboard
id="histograms"
backend="[[_backend]]"
></tf-histogram-dashboard>
</template>
</div> </div>
</paper-header-panel> </paper-header-panel>
@ -230,6 +238,9 @@ allows the user to toggle between various dashboards.
_modeIsDistributions: function(mode) { _modeIsDistributions: function(mode) {
return mode === "distributions"; return mode === "distributions";
}, },
_modeIsHistograms: function(mode) {
return mode === "histograms";
},
selectedDashboard: function() { selectedDashboard: function() {
var dashboard = this.$$("#" + this.mode); var dashboard = this.$$("#" + this.mode);
if (dashboard == null) { if (dashboard == null) {

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,14 @@
<!doctype html>
<html>
<head>
<title>vz-histogram-timeseries</title>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<script src="../webcomponentsjs/webcomponents-lite.js"></script>
<link rel="import" href="../iron-component-page/iron-component-page.html">
</head>
<body>
<iron-component-page src="vz-histogram-timeseries.html"></iron-component-page>
</body>
</html>

View File

@ -1,15 +1,45 @@
<link rel="import" href="../polymer/polymer.html"> <link rel="import" href="../polymer/polymer.html">
<link rel="import" href="../tf-imports/d3.html">
<!--
vz-histogram-timeseries creates an element that draws beautiful histograms for
displaying how data is distributed over time.
This histogram supports changing the time axis type and different modes of
visualization.
@demo
-->
<dom-module id="vz-histogram-timeseries"> <dom-module id="vz-histogram-timeseries">
<template> <template>
<svg id="svg">
<g>
<g class="axis x"></g>
<g class="axis y"></g>
<g class="axis y slice"></g>
<g class="stage">
<rect class="background"></rect>
</g>
<g class="x-axis-hover"></g>
</g>
</svg>
<style> <style>
:host { :host {
display: block; display: flex;
flex-direction: column;
flex-grow: 1;
flex-shrink: 1;
position: relative;
} }
svg { svg {
font-family: roboto, sans-serif; font-family: roboto, sans-serif;
overflow: visible;
display: block;
width: 100%;
flex-grow: 1;
flex-shrink: 1;
} }
.background { .background {
@ -100,52 +130,144 @@
.large .axis .tick:nth-child(2n + 1) text { display: block; } .large .axis .tick:nth-child(2n + 1) text { display: block; }
</style> </style>
<svg id="svg">
<g>
<g class="axis x"></g>
<g class="axis y"></g>
<g class="axis y slice"></g>
<g class="stage">
<rect class="background"></rect>
</g>
<g class="x-axis-hover"></g>
</g>
</svg>
</template> </template>
<script> <script>
"use strict";
Polymer({ Polymer({
is: "vz-histogram-timeseries", is: "vz-histogram-timeseries",
properties: { properties: {
mode: { type: String, value: "offset" }, //offset | overlay /**
width: { type: Number, value: 500 }, * Defines which view mode is being used by the chart. Supported values
height: { type: Number, value: 500 }, * are:
timeProperty: { type: String, value: "step" }, * - "offset" - Offset view of the data showing all timesteps.
bins: { type: String, value: "bins" }, * - "overlay" - Overlays all timesteps into one 2D view, with the
x: { type: String, value: "x" }, * brighter lines representing the newer timesteps.
dx: { type: String, value: "dx" }, */
y: { type: String, value: "y" }, mode: {
data: { type: Array, value: function(){ return [{ step: 0, bins: [{ x: 0, dx: 1, y: 0 }] }, { step: 1, bins: [{ x: 0, dx: 1, y: 0 }] }];}} type: String,
// type: HistogramSeriesDatum[] as described in vz-histogram-timeseries.d.ts value: "offset"
},
/*
* The name of the datum's property that contains the time values.
* Allows:
* - "step" - Linear scale using the "step" property of the datum.
* - "wall_time" - Temporal scale using the "wall_time" property of the
* datum.
* - "relative" - Temporal scale starting at 0 created by using
* the "wall_time" property of the datum.
*/
timeProperty: {
type: String,
value: "step"
},
/**
* The name of the data's property that contains the bins.
*/
bins: {
type: String,
value: "bins"
},
/**
* The name of the datum's property that contains the x values.
*/
x: {
type: String,
value: "x"
},
/**
* The name of the datum's property that contains the bin width values.
*/
dx: {
type: String,
value: "dx"
},
/**
* The name of the datum's property that contains the bin height.
*/
y: {
type: String,
value: "y"
},
/**
* Scale that maps series names to colors. The default colors are from
* d3.scale.category10() scale. Use this property to replace the default
* line colors with colors of your own choice.
*/
colorScale: {
type: Object,
value: function() {
return d3.scale.category10();
}
},
/**
* Duration of the transition between histogram modes.
*/
modeTransitionDuration: {
type: Number,
value: 500
},
_attached: Boolean,
_name: { type: String, value: null },
_data: { type: Array, value: null },
}, },
observers: [
'redraw(timeProperty, _attached)',
'_modeRedraw(mode)'
],
ready: function() { ready: function() {
// Polymer's way of scoping styles on nodes that d3 created // Polymer's way of scoping styles on nodes that d3 created
this.scopeSubtree(this.$["svg"], true); this.scopeSubtree(this.$.svg, true);
}, },
draw: function(duration) { attached: function() {
this._attached = true;
},
detached: function() {
this._attached = false;
},
setVisibleSeries: function(names) {
// Do nothing.
},
setSeriesData: function(name, data) {
this._name = name;
this._data = data;
this.redraw();
},
/**
* Redraws the chart. This is only called if the chart is attached to the
* screen and if the chart has data.
*/
redraw: function() {
this._draw(0);
},
_modeRedraw: function() {
this._draw(this.modeTransitionDuration);
},
_draw: function(duration) {
if (!this._attached || !this._data) {
return;
}
// //
// Data verification // Data verification
// //
if (!(this.data.length > 0)) throw(new Error("Not enough steps in the data")); if (duration === undefined) throw(new Error("vz-histogram-timeseries _draw needs duration"));
if (!this.data[0].hasOwnProperty(this.timeProperty)) throw(new Error("No time property of '" + this.timeProperty + "' in data")); if (this._data.length <= 0) throw(new Error("Not enough steps in the data"));
if (!this.data[0].hasOwnProperty(this.bins)) throw(new Error("No bins property of '" + this.bins + "' in data")); if (!this._data[0].hasOwnProperty(this.bins)) throw(new Error("No bins property of '" + this.bins + "' in data"));
if (!(this.data[0][this.bins].length > 0)) throw(new Error("Must have at least one bin in bins in data")); if (this._data[0][this.bins].length <= 0) throw(new Error("Must have at least one bin in bins in data"));
if (!this.data[0][this.bins][0].hasOwnProperty(this.x)) throw(new Error("No x property '" + this.x + "' on bins data")); if (!this._data[0][this.bins][0].hasOwnProperty(this.x)) throw(new Error("No x property '" + this.x + "' on bins data"));
if (!this.data[0][this.bins][0].hasOwnProperty(this.dx)) throw(new Error("No dx property '" + this.dx + "' on bins data")); if (!this._data[0][this.bins][0].hasOwnProperty(this.dx)) throw(new Error("No dx property '" + this.dx + "' on bins data"));
if (!this.data[0][this.bins][0].hasOwnProperty(this.y)) throw(new Error("No y property '" + this.y + "' on bins data")); if (!this._data[0][this.bins][0].hasOwnProperty(this.y)) throw(new Error("No y property '" + this.y + "' on bins data"));
// //
// Initialization // Initialization
@ -156,18 +278,24 @@
var dxProp = this.dx; var dxProp = this.dx;
var yProp = this.y; var yProp = this.y;
var xAccessor = (d) => d[xProp]; var data = this._data;
var yAccessor = (d) => d[yProp]; var name = this._name;
var dxAccessor = (d) => d[dxProp];
var xRightAccessor = (d) => d[xProp] + d[dxProp];
var timeAccessor = (d) => d[timeProp];
var duration = duration | 0;
var data = this.data;
var mode = this.mode; var mode = this.mode;
var color = d3.hcl(this.colorScale(name));
var outerWidth = this.width, var xAccessor = function(d) { return d[xProp] };
outerHeight = this.height; var yAccessor = function(d) { return d[yProp] };
var dxAccessor = function(d) { return d[dxProp] };
var xRightAccessor = function(d) { return d[xProp] + d[dxProp] };
var timeAccessor = function(d) { return d[timeProp] };
if (timeProp === "relative") {
timeAccessor = function(d) { return d.wall_time - data[0].wall_time };
}
var brect = this.$.svg.getBoundingClientRect();
var outerWidth = brect.width,
outerHeight = brect.height;
var sliceHeight, var sliceHeight,
margin = {top: 5, right: 60, bottom: 20, left: 24}; margin = {top: 5, right: 60, bottom: 20, left: 24};
@ -188,8 +316,16 @@
// //
// Text formatters // Text formatters
// //
var formatTime = d3.time.format("%x"), var format = d3.format(".3n");
format = d3.format(".3n"); var yAxisFormat = d3.format(".0f");
if (timeProp === "wall_time") {
yAxisFormat = d3.time.format("%X");
} else if (timeProp === "relative") {
yAxisFormat = function(d) {
return d3.format(".1r")(d / 3.6e6); // Convert to hours.
};
}
// //
// Calculate the extents // Calculate the extents
@ -209,8 +345,10 @@
// //
var outlineCanvasSize = 500; var outlineCanvasSize = 500;
var yScale = (timeProp === "step" ? d3.scale.linear() : d3.time.scale()) var extent = d3.extent(data, timeAccessor);
.domain(d3.extent(data, timeAccessor))
var yScale = (timeProp === "wall_time" ? d3.time.scale() : d3.scale.linear())
.domain(extent)
.range([0, (mode === "offset" ? height : 0)]); .range([0, (mode === "offset" ? height : 0)]);
var ySliceScale = d3.scale.linear() var ySliceScale = d3.scale.linear()
@ -235,7 +373,7 @@
var outlineColor = d3.scale.linear() var outlineColor = d3.scale.linear()
.domain(d3.extent(data, timeAccessor)) .domain(d3.extent(data, timeAccessor))
.range(["#FFA726", "#BF360C"]) .range([color.darker(), color.brighter()])
.interpolate(d3.interpolateHcl); .interpolate(d3.interpolateHcl);
var xAxis = d3.svg.axis() var xAxis = d3.svg.axis()
@ -245,20 +383,31 @@
var yAxis = d3.svg.axis() var yAxis = d3.svg.axis()
.scale(yScale) .scale(yScale)
.ticks(Math.max(2, width / 20)) .ticks(Math.max(2, height / 15))
.tickFormat(yAxisFormat)
.orient("right"); .orient("right");
var ySliceAxis = d3.svg.axis() var ySliceAxis = d3.svg.axis()
.scale(ySliceScale) .scale(ySliceScale)
.ticks(Math.max(2, width / 20)) .ticks(Math.max(2, height / 15))
.tickSize(width + 5) .tickSize(width + 5)
.orient("right"); .orient("right");
var path = d3.svg.area() var xBinCentroid = function(d) {
return d[xProp] + d[dxProp] / 2;
};
var linePath = d3.svg.line()
.interpolate("linear") .interpolate("linear")
.x(function(d) { return xLineScale(d[xProp] + d[dxProp] / 2); }) .x(function(d) { return xLineScale(xBinCentroid(d)); })
.y0(function(d) { return yLineScale(0); }) .y(function(d) { return yLineScale(d[yProp]); });
.y1(function(d) { return yLineScale(d[yProp]); });
var path = function(d) {
// Draw a line from 0 to the first point and from the last point to 0.
return 'M' + xLineScale(xBinCentroid(d[0])) + ',' + yLineScale(0) +
'L' + linePath(d).slice(1) +
"L" + xLineScale(xBinCentroid(d[d.length - 1])) + "," + yLineScale(0);
};
// //
// Render // Render
@ -318,14 +467,14 @@
.attr("width", outerWidth) .attr("width", outerWidth)
.attr("height", outerHeight); .attr("height", outerHeight);
var histogram = stage.selectAll(".histogram").data(data, function(d) { return d[timeProp]; }), var histogram = stage.selectAll(".histogram").data(data),
histogramExit = histogram.exit().remove(), histogramExit = histogram.exit().remove(),
histogramEnter = histogram.enter().append("g").attr("class", "histogram"), histogramEnter = histogram.enter().append("g").attr("class", "histogram"),
histogramUpdate = histogram histogramUpdate = histogram
.sort(function(a, b) { return a[timeProp] - b[timeProp]; }), .sort(function(a, b) { return timeAccessor(a) - timeAccessor(b); }),
histogramTransition = gTransition.selectAll(".histogram") histogramTransition = gTransition.selectAll(".histogram")
.attr("transform", function(d) { .attr("transform", function(d) {
return "translate(0, " + (mode === "offset" ? (yScale(d[timeProp]) - sliceHeight) : 0) + ")"; return "translate(0, " + (mode === "offset" ? (yScale(timeAccessor(d)) - sliceHeight) : 0) + ")";
}); });
var baselineEnter = histogramEnter.append("line").attr("class", "baseline"), var baselineEnter = histogramEnter.append("line").attr("class", "baseline"),
@ -342,14 +491,14 @@
.style("stroke-width", 1), .style("stroke-width", 1),
outlineTransition = histogramTransition.select(".outline") outlineTransition = histogramTransition.select(".outline")
.attr("transform", "scale(" + width / outlineCanvasSize + ", " + sliceHeight / outlineCanvasSize + ")") .attr("transform", "scale(" + width / outlineCanvasSize + ", " + sliceHeight / outlineCanvasSize + ")")
.style("stroke", function(d) { return (mode === "offset" ? "white" : outlineColor(d[timeProp])); }) .style("stroke", function(d) { return (mode === "offset" ? "white" : outlineColor(timeAccessor(d))); })
.style("fill-opacity", function(d) { return (mode === "offset" ? 1 : 0); }) .style("fill-opacity", function(d) { return (mode === "offset" ? 1 : 0); })
.style("fill", function(d) { return outlineColor(d[timeProp]); }); .style("fill", function(d) { return outlineColor(timeAccessor(d)); });
var hoverEnter = histogramEnter.append("g") var hoverEnter = histogramEnter.append("g")
.attr("class", "hover") .attr("class", "hover")
.style("fill", function(d) { return outlineColor(d[timeProp]); }), .style("fill", function(d) { return outlineColor(timeAccessor(d)); }),
hoverUpdate = histogramUpdate.select(".hover"); hoverUpdate = histogramUpdate.select(".hover");
hoverEnter.append("circle") hoverEnter.append("circle")
@ -397,7 +546,6 @@
.style("opacity", mode === "offset" ? 1 : 0) .style("opacity", mode === "offset" ? 1 : 0)
.attr("transform", "translate(" + width + ", " + (mode === "offset" ? 0 : height) + ")") .attr("transform", "translate(" + width + ", " + (mode === "offset" ? 0 : height) + ")")
.call(yAxis); .call(yAxis);
} }
}); });
</script> </script>

View File

@ -225,11 +225,11 @@ smoothing.
this.scopeSubtree(this.$.chartsvg, true); this.scopeSubtree(this.$.chartsvg, true);
}, },
_makeChart: function(xType, colorScale, _attached) { _makeChart: function(xType, colorScale, _attached) {
if (this._makeChartAsyncHandle === null) { if (this._makeChartAsyncCallbackId === null) {
this.cancelAsync(this._makeChartAsyncCallbackId); this.cancelAsync(this._makeChartAsyncCallbackId);
} }
this._makeChartAsyncHandle = this.async(function() { this._makeChartAsyncCallbackId = this.async(function() {
this._makeChartAsyncCallbackId = null; this._makeChartAsyncCallbackId = null;
if (!this._attached) return; if (!this._attached) return;
if (this._chart) this._chart.destroy(); if (this._chart) this._chart.destroy();
@ -238,7 +238,7 @@ smoothing.
var svg = d3.select(this.$.chartsvg); var svg = d3.select(this.$.chartsvg);
chart.renderTo(svg); chart.renderTo(svg);
this._chart = chart; this._chart = chart;
}.bind(this), 350); }, 350);
}, },
_reloadFromCache: function() { _reloadFromCache: function() {
if(this._chart) { if(this._chart) {

View File

@ -419,6 +419,10 @@ module VZ {
this.datasets = names.map((r) => this.getDataset(r)); this.datasets = names.map((r) => this.getDataset(r));
this.datasets.forEach((d) => d.onUpdate(this.onDatasetChanged)); this.datasets.forEach((d) => d.onUpdate(this.onDatasetChanged));
this.linePlot.datasets(this.datasets); this.linePlot.datasets(this.datasets);
if (this.smoothingEnabled) {
this.smoothLinePlot.datasets(this.datasets);
}
} }
/** /**

File diff suppressed because it is too large Load Diff

View File

@ -42,14 +42,14 @@ Instead, use `gulp regenerate` to create a new version with your changes.\n\
/** /**
* Returns a list of non-tensorboard components inside the components * Returns a list of non-tensorboard components inside the components
* directory, i.e. components that don't begin with 'tf-'. * directory, i.e. components that don't begin with 'tf-' or 'vz-''.
*/ */
function getNonTensorBoardComponents() { function getNonTensorBoardComponents() {
return fs.readdirSync('components') return fs.readdirSync('components')
.filter(function(file) { .filter(function(file) {
var prefix = file.slice(0,3); var prefix = file.slice(0,3);
return fs.statSync(path.join('components', file)).isDirectory() && return fs.statSync(path.join('components', file)).isDirectory() &&
prefix !== 'tf-'; prefix !== 'tf-' && prefix !== 'vz-';
}) })
.map(function(dir) { return '/' + dir + '/'; }); .map(function(dir) { return '/' + dir + '/'; });
} }

View File

@ -94,7 +94,7 @@ def main(unused_argv=None):
if FLAGS.inspect: if FLAGS.inspect:
logging.info('Not bringing up TensorBoard, but inspecting event files.') logging.info('Not bringing up TensorBoard, but inspecting event files.')
efi.inspect(logdir=FLAGS.logdir, efi.inspect(logdir=FLAGS.logdir,
event_file=FLAGS.event_file, event_file=os.path.expanduser(FLAGS.event_file),
tag=FLAGS.tag) tag=FLAGS.tag)
return 0 return 0