Automated rollback of change 142681535
Change: 142694447
parent e884a9cadb
commit e959223fd3
@@ -450,23 +450,12 @@ cc_library(
    alwayslink = 1,
)

cc_library(
    name = "cloud_ops_op_lib",
    srcs = ["ops/cloud_ops.cc"],
    copts = tf_copts(),
    linkstatic = 1,
    visibility = ["//visibility:public"],
    deps = [":framework"],
    alwayslink = 1,
)

cc_library(
    name = "ops",
    visibility = ["//visibility:public"],
    deps = [
        ":array_ops_op_lib",
        ":candidate_sampling_ops_op_lib",
        ":cloud_ops_op_lib",
        ":control_flow_ops_op_lib",
        ":ctc_ops_op_lib",
        ":data_flow_ops_op_lib",
@@ -613,7 +602,6 @@ cc_library(
        "//tensorflow/core/kernels:string",
        "//tensorflow/core/kernels:training_ops",
        "//tensorflow/core/kernels:word2vec_kernels",
        "//tensorflow/core/kernels/cloud:bigquery_reader_ops",
    ] + if_not_windows([
        "//tensorflow/core/kernels:fact_op",
        "//tensorflow/core/kernels:array_not_windows",

@@ -9,7 +9,6 @@ licenses(["notice"])  # Apache 2.0

load(
    "//tensorflow:tensorflow.bzl",
    "tf_kernel_library",
    "tf_cc_test",
)

@@ -31,24 +30,6 @@ filegroup(
    visibility = ["//tensorflow:__subpackages__"],
)

tf_kernel_library(
    name = "bigquery_reader_ops",
    srcs = [
        "bigquery_reader_ops.cc",
    ],
    visibility = ["//visibility:public"],
    deps = [
        ":bigquery_table_accessor",
        ":bigquery_table_partition_proto_cc",
        "//tensorflow/core:cloud_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/kernels:reader_base",
    ],
)

cc_library(
    name = "bigquery_table_accessor",
    srcs = [

@@ -1,193 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <map>
#include <memory>
#include <set>

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/reader_op_kernel.h"
#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h"
#include "tensorflow/core/kernels/cloud/bigquery_table_partition.pb.h"
#include "tensorflow/core/kernels/reader_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/numbers.h"

namespace tensorflow {
namespace {

constexpr int64 kDefaultRowBufferSize = 1000;  // Number of rows to buffer.

// This is a helper function for reading table attributes from context.
Status GetTableAttrs(OpKernelConstruction* context, string* project_id,
                     string* dataset_id, string* table_id,
                     int64* timestamp_millis, std::vector<string>* columns,
                     string* test_end_point) {
  TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id));
  TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id));
  TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id));
  TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis));
  TF_RETURN_IF_ERROR(context->GetAttr("columns", columns));
  TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point));
  return Status::OK();
}

}  // namespace

// Note that overriden methods with names ending in "Locked" are called by
// ReaderBase while a mutex is held.
// See comments for ReaderBase.
class BigQueryReader : public ReaderBase {
 public:
  explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor,
                          const string& node_name)
      : ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")),
        bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {}

  Status OnWorkStartedLocked() override {
    BigQueryTablePartition partition;
    if (!partition.ParseFromString(current_work())) {
      return errors::InvalidArgument(
          "Could not parse work as as valid partition.");
    }
    TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition));
    return Status::OK();
  }

  Status ReadLocked(string* key, string* value, bool* produced,
                    bool* at_end) override {
    *at_end = false;
    *produced = false;
    if (bigquery_table_accessor_->Done()) {
      *at_end = true;
      return Status::OK();
    }

    Example example;
    int64 row_id;
    TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example));

    *key = std::to_string(row_id);
    *value = example.SerializeAsString();
    *produced = true;
    return Status::OK();
  }

 private:
  // Not owned.
  BigQueryTableAccessor* bigquery_table_accessor_;
};

class BigQueryReaderOp : public ReaderOpKernel {
 public:
  explicit BigQueryReaderOp(OpKernelConstruction* context)
      : ReaderOpKernel(context) {
    string table_id;
    string project_id;
    string dataset_id;
    int64 timestamp_millis;
    std::vector<string> columns;
    string test_end_point;

    OP_REQUIRES_OK(context,
                   GetTableAttrs(context, &project_id, &dataset_id, &table_id,
                                 &timestamp_millis, &columns, &test_end_point));
    OP_REQUIRES_OK(context,
                   BigQueryTableAccessor::New(
                       project_id, dataset_id, table_id, timestamp_millis,
                       kDefaultRowBufferSize, test_end_point, columns,
                       BigQueryTablePartition(), &bigquery_table_accessor_));

    SetReaderFactory([this]() {
      return new BigQueryReader(bigquery_table_accessor_.get(), name());
    });
  }

 private:
  std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
};

REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU),
                        BigQueryReaderOp);

class GenerateBigQueryReaderPartitionsOp : public OpKernel {
 public:
  explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string project_id;
    string dataset_id;
    string table_id;
    int64 timestamp_millis;
    std::vector<string> columns;
    string test_end_point;

    OP_REQUIRES_OK(context,
                   GetTableAttrs(context, &project_id, &dataset_id, &table_id,
                                 &timestamp_millis, &columns, &test_end_point));
    OP_REQUIRES_OK(context,
                   BigQueryTableAccessor::New(
                       project_id, dataset_id, table_id, timestamp_millis,
                       kDefaultRowBufferSize, test_end_point, columns,
                       BigQueryTablePartition(), &bigquery_table_accessor_));
    OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context));
    OP_REQUIRES_OK(context, InitializeTotalNumberOfRows());
  }

  void Compute(OpKernelContext* context) override {
    const int64 partition_size = tensorflow::MathUtil::CeilOfRatio<int64>(
        total_num_rows_, num_partitions_);
    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, TensorShape({num_partitions_}),
                                            &output_tensor));

    auto output = output_tensor->template flat<string>();
    for (int64 i = 0; i < num_partitions_; ++i) {
      BigQueryTablePartition partition;
      partition.set_start_index(i * partition_size);
      partition.set_end_index(
          std::min(total_num_rows_, (i + 1) * partition_size) - 1);
      output(i) = partition.SerializeAsString();
    }
  }

 private:
  Status InitializeTotalNumberOfRows() {
    total_num_rows_ = bigquery_table_accessor_->total_num_rows();
    if (total_num_rows_ <= 0) {
      return errors::FailedPrecondition("Invalid total number of rows.");
    }
    return Status::OK();
  }

  Status InitializeNumberOfPartitions(OpKernelConstruction* context) {
    TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_));
    if (num_partitions_ <= 0) {
      return errors::FailedPrecondition("Invalid number of partitions.");
    }
    return Status::OK();
  }

  int64 num_partitions_;
  int64 total_num_rows_;
  std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
};

REGISTER_KERNEL_BUILDER(
    Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU),
    GenerateBigQueryReaderPartitionsOp);

}  // namespace tensorflow

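For reference, the partition arithmetic in `GenerateBigQueryReaderPartitionsOp::Compute` above splits the table into `num_partitions` contiguous ranges of `CeilOfRatio(total_num_rows, num_partitions)` rows each, with an inclusive end index clamped to the last row. A minimal Python sketch of that arithmetic (illustrative only, not part of this change; the function name is made up):

```python
# Illustrative sketch of the partition arithmetic used by
# GenerateBigQueryReaderPartitionsOp::Compute above.

def generate_partitions(total_num_rows, num_partitions):
  """Returns (start_index, end_index) pairs; end_index is inclusive."""
  # Mirrors MathUtil::CeilOfRatio(total_num_rows, num_partitions).
  partition_size = -(-total_num_rows // num_partitions)
  partitions = []
  for i in range(num_partitions):
    start = i * partition_size
    end = min(total_num_rows, (i + 1) * partition_size) - 1
    partitions.append((start, end))
  return partitions

# Example: 10 rows in 4 partitions -> [(0, 2), (3, 5), (6, 8), (9, 9)]
print(generate_partitions(10, 4))
```
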
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h"

#include "tensorflow/core/example/feature.pb.h"
@@ -22,15 +23,6 @@ namespace tensorflow {
namespace {

constexpr size_t kBufferSize = 1024 * 1024;  // In bytes.
const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2";

bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
  if (partition.end_index() != -1 &&
      partition.end_index() < partition.start_index()) {
    return true;
  }
  return false;
}

Status ParseJson(StringPiece json, Json::Value* result) {
  Json::Reader reader;
@@ -100,18 +92,17 @@ Status ParseColumnType(const string& type,

Status BigQueryTableAccessor::New(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    int64 timestamp_millis, int64 row_buffer_size,
    const std::set<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<BigQueryTableAccessor>* accessor) {
  return New(project_id, dataset_id, table_id, timestamp_millis,
             row_buffer_size, end_point, columns, partition, nullptr, nullptr,
             accessor);
             row_buffer_size, columns, partition, nullptr, nullptr, accessor);
}

Status BigQueryTableAccessor::New(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    int64 timestamp_millis, int64 row_buffer_size,
    const std::set<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<AuthProvider> auth_provider,
    std::unique_ptr<HttpRequest::Factory> http_request_factory,
    std::unique_ptr<BigQueryTableAccessor>* accessor) {
@@ -119,16 +110,14 @@ Status BigQueryTableAccessor::New(
    return errors::InvalidArgument(
        "Cannot use zero or negative timestamp to query a table.");
  }
  const string& big_query_end_point =
      end_point.empty() ? kBigQueryEndPoint : end_point;
  if (auth_provider == nullptr && http_request_factory == nullptr) {
    accessor->reset(new BigQueryTableAccessor(
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
        big_query_end_point, columns, partition));
    accessor->reset(new BigQueryTableAccessor(project_id, dataset_id, table_id,
                                              timestamp_millis, row_buffer_size,
                                              columns, partition));
  } else {
    accessor->reset(new BigQueryTableAccessor(
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
        big_query_end_point, columns, partition, std::move(auth_provider),
        columns, partition, std::move(auth_provider),
        std::move(http_request_factory)));
  }
  return (*accessor)->ReadSchema();
@@ -136,11 +125,11 @@ Status BigQueryTableAccessor::New(

BigQueryTableAccessor::BigQueryTableAccessor(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition)
    int64 timestamp_millis, int64 row_buffer_size,
    const std::set<string>& columns, const BigQueryTablePartition& partition)
    : BigQueryTableAccessor(
          project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
          end_point, columns, partition,
          columns, partition,
          std::unique_ptr<AuthProvider>(new GoogleAuthProvider()),
          std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory())) {
  row_buffer_.resize(row_buffer_size);
@@ -148,16 +137,15 @@ BigQueryTableAccessor::BigQueryTableAccessor(

BigQueryTableAccessor::BigQueryTableAccessor(
    const string& project_id, const string& dataset_id, const string& table_id,
    int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
    const std::vector<string>& columns, const BigQueryTablePartition& partition,
    int64 timestamp_millis, int64 row_buffer_size,
    const std::set<string>& columns, const BigQueryTablePartition& partition,
    std::unique_ptr<AuthProvider> auth_provider,
    std::unique_ptr<HttpRequest::Factory> http_request_factory)
    : project_id_(project_id),
      dataset_id_(dataset_id),
      table_id_(table_id),
      timestamp_millis_(timestamp_millis),
      columns_(columns.begin(), columns.end()),
      bigquery_end_point_(end_point),
      columns_(columns),
      partition_(partition),
      auth_provider_(std::move(auth_provider)),
      http_request_factory_(std::move(http_request_factory)) {
@@ -165,14 +153,10 @@ BigQueryTableAccessor::BigQueryTableAccessor(
  Reset();
}

Status BigQueryTableAccessor::SetPartition(
void BigQueryTableAccessor::SetPartition(
    const BigQueryTablePartition& partition) {
  if (partition.start_index() < 0) {
    return errors::InvalidArgument("Start index cannot be negative.");
  }
  partition_ = partition;
  Reset();
  return Status::OK();
}

void BigQueryTableAccessor::Reset() {
@@ -188,8 +172,7 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {

  // If the next row is already fetched and cached, return the row from the
  // buffer. Otherwise, fill up the row buffer from BigQuery and return a row.
  if (next_row_in_buffer_ != -1 &&
      next_row_in_buffer_ < ComputeMaxResultsArg()) {
  if (next_row_in_buffer_ != -1 && next_row_in_buffer_ < row_buffer_.size()) {
    *row_id = first_buffered_row_index_ + next_row_in_buffer_;
    *example = row_buffer_[next_row_in_buffer_];
    next_row_in_buffer_++;
@@ -207,12 +190,12 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
  // we use the page token (which returns rows faster).
  if (!next_page_token_.empty()) {
    TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
        BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
        BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(),
        "&pageToken=", request->EscapeString(next_page_token_))));
    first_buffered_row_index_ += row_buffer_.size();
  } else {
    TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
        BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
        BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(),
        "&startIndex=", first_buffered_row_index_)));
  }
  TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
@@ -239,18 +222,6 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
  return Status::OK();
}

int64 BigQueryTableAccessor::ComputeMaxResultsArg() {
  if (partition_.end_index() == -1) {
    return row_buffer_.size();
  }
  if (IsPartitionEmpty(partition_)) {
    return 0;
  }
  return std::min(static_cast<int64>(row_buffer_.size()),
                  static_cast<int64>(partition_.end_index() -
                                     partition_.start_index() + 1));
}

Status BigQueryTableAccessor::ParseColumnValues(
    const Json::Value& value, const SchemaNode& root_schema_node,
    Example* example) {
@@ -393,17 +364,21 @@ Status BigQueryTableAccessor::AppendValueToExample(

string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() {
  HttpRequest request;
  return strings::StrCat(bigquery_end_point_, "/projects/",
  return strings::StrCat("https://www.googleapis.com/bigquery/v2/projects/",
                         request.EscapeString(project_id_), "/datasets/",
                         request.EscapeString(dataset_id_), "/tables/",
                         request.EscapeString(table_id_), "/");
}

string BigQueryTableAccessor::FullTableName() {
  return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
                         timestamp_millis_);
}

bool BigQueryTableAccessor::Done() {
  return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) ||
         IsPartitionEmpty(partition_) ||
         (partition_.end_index() != -1 &&
          partition_.end_index() <
          partition_.end_index() <=
              first_buffered_row_index_ + next_row_in_buffer_);
}

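The `ComputeMaxResultsArg()` helper shown in the hunks above clamps each BigQuery `maxResults` request to the row buffer size and to the remaining rows of a bounded partition (end index inclusive, `-1` meaning unbounded; an empty partition requests nothing). A small Python sketch of that clamping logic, for illustration only:

```python
# Sketch of the ComputeMaxResultsArg() logic shown above (illustrative only).

def compute_max_results(row_buffer_size, start_index, end_index):
  if end_index == -1:          # Unbounded partition: fill the whole buffer.
    return row_buffer_size
  if end_index < start_index:  # Empty partition (IsPartitionEmpty above).
    return 0
  return min(row_buffer_size, end_index - start_index + 1)

assert compute_max_results(1000, 0, -1) == 1000
assert compute_max_results(1000, 3, 2) == 0
assert compute_max_results(1000, 0, 9) == 10
```
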
@@ -55,23 +55,16 @@ class BigQueryTableAccessor {
  };

  /// \brief Creates a new BigQueryTableAccessor object.
  //
  // We do not allow relative (negative or zero) snapshot times here since we
  // want to have a consistent snapshot of the table for the lifetime of this
  // object.
  // Use end_point if you want to connect to a different end point than the
  // official BigQuery end point. Otherwise send an empty string.
  static Status New(const string& project_id, const string& dataset_id,
                    const string& table_id, int64 timestamp_millis,
                    int64 row_buffer_size, const string& end_point,
                    const std::vector<string>& columns,
                    int64 row_buffer_size, const std::set<string>& columns,
                    const BigQueryTablePartition& partition,
                    std::unique_ptr<BigQueryTableAccessor>* accessor);

  /// \brief Starts reading a new partition.
  Status SetPartition(const BigQueryTablePartition& partition);
  void SetPartition(const BigQueryTablePartition& partition);

  /// \brief Returns true if there are more rows available in the current
  /// \brief Returns false if there are more rows available in the current
  /// partition.
  bool Done();

@@ -81,11 +74,9 @@ class BigQueryTableAccessor {
  /// in the BigQuery service.
  Status ReadRow(int64* row_id, Example* example);

  /// \brief Returns total number of rows in the table.
  /// \brief Returns total number of rows.
  int64 total_num_rows() { return total_num_rows_; }

  virtual ~BigQueryTableAccessor() {}

 private:
  friend class BigQueryTableAccessorTest;

@@ -104,8 +95,7 @@ class BigQueryTableAccessor {
  /// these two variables.
  static Status New(const string& project_id, const string& dataset_id,
                    const string& table_id, int64 timestamp_millis,
                    int64 row_buffer_size, const string& end_point,
                    const std::vector<string>& columns,
                    int64 row_buffer_size, const std::set<string>& columns,
                    const BigQueryTablePartition& partition,
                    std::unique_ptr<AuthProvider> auth_provider,
                    std::unique_ptr<HttpRequest::Factory> http_request_factory,
@@ -114,16 +104,14 @@ class BigQueryTableAccessor {
  /// \brief Constructs an object for a given table and partition.
  BigQueryTableAccessor(const string& project_id, const string& dataset_id,
                        const string& table_id, int64 timestamp_millis,
                        int64 row_buffer_size, const string& end_point,
                        const std::vector<string>& columns,
                        int64 row_buffer_size, const std::set<string>& columns,
                        const BigQueryTablePartition& partition);

  /// Used for unit testing.
  BigQueryTableAccessor(
      const string& project_id, const string& dataset_id,
      const string& table_id, int64 timestamp_millis, int64 row_buffer_size,
      const string& end_point, const std::vector<string>& columns,
      const BigQueryTablePartition& partition,
      const std::set<string>& columns, const BigQueryTablePartition& partition,
      std::unique_ptr<AuthProvider> auth_provider,
      std::unique_ptr<HttpRequest::Factory> http_request_factory);

@@ -144,7 +132,7 @@ class BigQueryTableAccessor {
  Status AppendValueToExample(const string& column_name,
                              const Json::Value& column_value,
                              const BigQueryTableAccessor::ColumnType type,
                              Example* example);
                              Example* ex);

  /// \brief Resets internal counters for reading a partition.
  void Reset();
@@ -152,28 +140,25 @@ class BigQueryTableAccessor {
  /// \brief Helper function that returns BigQuery http endpoint prefix.
  string BigQueryUriPrefix();

  /// \brief Computes the maxResults arg to send to BigQuery.
  int64 ComputeMaxResultsArg();

  /// \brief Returns full name of the underlying table name.
  string FullTableName() {
    return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
                           timestamp_millis_);
  }
  string FullTableName();

  const string project_id_;
  const string dataset_id_;
  const string table_id_;

  // Snapshot timestamp.
  //
  // Indicates a snapshot of the table in milliseconds since the epoch.
  //
  // We do not allow relative (negative or zero) times here since we want to
  // have a consistent snapshot of the table for the lifetime of this object.
  // For more details, see 'Table Decorators' in BigQuery documentation.
  const int64 timestamp_millis_;

  // Columns that should be read. Empty means all columns.
  const std::set<string> columns_;

  // HTTP address of BigQuery end point to use.
  const string bigquery_end_point_;

  // Describes the portion of the table that we are currently accessing.
  BigQueryTablePartition partition_;

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

namespace {

constexpr char kTestProject[] = "test-project";
@@ -68,10 +69,10 @@ class BigQueryTableAccessorTest : public ::testing::Test {
  Status CreateTableAccessor(const string& project_id, const string& dataset_id,
                             const string& table_id, int64 timestamp_millis,
                             int64 row_buffer_size,
                             const std::vector<string>& columns,
                             const std::set<string>& columns,
                             const BigQueryTablePartition& partition) {
    return BigQueryTableAccessor::New(
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "",
        project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
        columns, partition, std::unique_ptr<AuthProvider>(new FakeAuthProvider),
        std::unique_ptr<HttpRequest::Factory>(
            new FakeHttpRequestFactory(&requests_)),
@@ -196,7 +197,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) {
      kTestRow));
  BigQueryTablePartition partition;
  partition.set_start_index(2);
  partition.set_end_index(2);
  partition.set_end_index(3);
  TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
                                   {}, partition));
  int64 row_id;
@@ -226,7 +227,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) {
      kTestRow));
  BigQueryTablePartition partition;
  partition.set_start_index(2);
  partition.set_end_index(2);
  partition.set_end_index(3);
  TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
                                   {"bool_field", "rec_field.float_field"},
                                   partition));
@@ -257,7 +258,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) {
      kTestRowWithNulls));
  BigQueryTablePartition partition;
  partition.set_start_index(2);
  partition.set_end_index(2);
  partition.set_end_index(3);
  TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
                                   {}, partition));
  int64 row_id;
@@ -287,7 +288,7 @@ TEST_F(BigQueryTableAccessorTest, BrokenRowTest) {
      kBrokenTestRow));
  BigQueryTablePartition partition;
  partition.set_start_index(2);
  partition.set_end_index(2);
  partition.set_end_index(3);
  TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
                                   {}, partition));
  int64 row_id;
@@ -356,7 +357,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
      kSampleSchema));
  requests_.emplace_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
      "datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n"
      "datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n"
      "Auth Token: fake_token\n",
      kTestTwoRows));
  requests_.emplace_back(new FakeHttpRequest(
@@ -373,7 +374,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {

  BigQueryTablePartition partition;
  partition.set_start_index(0);
  partition.set_end_index(0);
  partition.set_end_index(1);
  TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2,
                                   {}, partition));

@@ -395,7 +396,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
      1234);

  partition.set_start_index(0);
  partition.set_end_index(1);
  partition.set_end_index(2);
  accessor_->SetPartition(partition);
  TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
  EXPECT_EQ(0, row_id);
@@ -409,23 +410,4 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
      2222);
}

TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) {
  requests_.emplace_back(new FakeHttpRequest(
      "Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
      "datasets/test-dataset/tables/test-table/\n"
      "Auth Token: fake_token\n",
      kSampleSchema));

  BigQueryTablePartition partition;
  partition.set_start_index(3);
  partition.set_end_index(2);
  TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
                                   {}, partition));
  EXPECT_TRUE(accessor_->Done());

  int64 row_id;
  Example example;
  EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example)));
}

}  // namespace tensorflow

@@ -1,88 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

/* This file registers all cloud ops. */

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {

using shape_inference::InferenceContext;

REGISTER_OP("BigQueryReader")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("project_id: string")
    .Attr("dataset_id: string")
    .Attr("table_id: string")
    .Attr("columns: list(string)")
    .Attr("timestamp_millis: int")
    .Attr("test_end_point: string = ''")
    .Output("reader_handle: Ref(string)")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->Vector(2));
      return Status::OK();
    })
    .Doc(R"doc(
A Reader that outputs rows from a BigQuery table as tensorflow Examples.

container: If non-empty, this reader is placed in the given container.
  Otherwise, a default container is used.
shared_name: If non-empty, this reader is named in the given bucket
  with this shared_name. Otherwise, the node name is used instead.
project_id: GCP project ID.
dataset_id: BigQuery Dataset ID.
table_id: Table to read.
columns: List of columns to read. Leave empty to read all columns.
timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
  (negative or zero) snapshot times are not allowed. For more details, see
  'Table Decorators' in BigQuery docs.
test_end_point: Do not use. For testing purposes only.
reader_handle: The handle to reference the Reader.
)doc");

REGISTER_OP("GenerateBigQueryReaderPartitions")
    .Attr("project_id: string")
    .Attr("dataset_id: string")
    .Attr("table_id: string")
    .Attr("columns: list(string)")
    .Attr("timestamp_millis: int")
    .Attr("num_partitions: int")
    .Attr("test_end_point: string = ''")
    .Output("partitions: string")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
      return Status::OK();
    })
    .Doc(R"doc(
Generates serialized partition messages suitable for batch reads.

This op should not be used directly by clients. Instead, the
bigquery_reader_ops.py file defines a clean interface to the reader.

project_id: GCP project ID.
dataset_id: BigQuery Dataset ID.
table_id: Table to read.
columns: List of columns to read. Leave empty to read all columns.
timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
  (negative or zero) snapshot times are not allowed. For more details, see
  'Table Decorators' in BigQuery docs.
num_partitions: Number of partitions to split the table into.
test_end_point: Do not use. For testing purposes only.
partitions: Serialized table partitions.
)doc");

}  // namespace tensorflow

@@ -35,7 +35,6 @@ py_library(
        ":check_ops",
        ":client",
        ":client_testlib",
        ":cloud_ops",
        ":confusion_matrix",
        ":control_flow_ops",
        ":errors",
@@ -790,11 +789,6 @@ tf_gen_op_wrapper_private_py(
    require_shape_functions = True,
)

tf_gen_op_wrapper_private_py(
    name = "cloud_ops_gen",
    require_shape_functions = True,
)

tf_gen_op_wrapper_private_py(
    name = "control_flow_ops_gen",
    require_shape_functions = True,
@@ -1437,19 +1431,6 @@ py_library(
    ],
)

py_library(
    name = "cloud_ops",
    srcs = [
        "ops/cloud/__init__.py",
        "ops/cloud/bigquery_reader_ops.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":cloud_ops_gen",
        ":framework",
    ],
)

py_library(
    name = "script_ops",
    srcs = ["ops/script_ops.py"],
@@ -2050,17 +2031,6 @@ cuda_py_test(
    ],
)

tf_py_test(
    name = "bigquery_reader_ops_test",
    size = "small",
    srcs = ["ops/cloud/bigquery_reader_ops_test.py"],
    additional_deps = [
        ":cloud_ops",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:util",
    ],
)

py_library(
    name = "training",
    srcs = glob(

@@ -83,7 +83,6 @@ from tensorflow.python.ops.standard_ops import *

# Bring in subpackages.
from tensorflow.python.layers import layers
from tensorflow.python.ops import cloud
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
from tensorflow.python.ops import sdca_ops as sdca
@@ -214,7 +213,6 @@ _allowed_symbols.extend([
_allowed_symbols.extend([
    'app',
    'compat',
    'cloud',
    'errors',
    'flags',
    'gfile',

@@ -1,22 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import cloud ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=wildcard-import
from tensorflow.python.ops.cloud.bigquery_reader_ops import *

@@ -1,150 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BigQuery reading support for TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_cloud_ops
from tensorflow.python.ops import io_ops


class BigQueryReader(io_ops.ReaderBase):
  """A Reader that outputs keys and tf.Example values from a BigQuery table.

  Example use:
    ```python
    # Assume a BigQuery has the following schema,
    #     name STRING,
    #     age INT,
    #     state STRING

    # Create the parse_examples list of features.
    features = dict(
        name=tf.FixedLenFeature([1], tf.string),
        age=tf.FixedLenFeature([1], tf.int32),
        state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK"))

    # Create a Reader.
    reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT,
                                                dataset_id=DATASET,
                                                table_id=TABLE,
                                                timestamp_millis=TIME,
                                                num_partitions=NUM_PARTITIONS,
                                                features=features)

    # Populate a queue with the BigQuery Table partitions.
    queue = tf.training.string_input_producer(reader.partitions())

    # Read and parse examples.
    row_id, examples_serialized = reader.read(queue)
    examples = tf.parse_example(examples_serialized, features=features)

    # Process the Tensors examples["name"], examples["age"], etc...
    ```

  Note that to create a reader a snapshot timestamp is necessary. This
  will enable the reader to look at a consistent snapshot of the table.
  For more information, see 'Table Decorators' in BigQuery docs.

  See ReaderBase for supported methods.
  """

  def __init__(self,
               project_id,
               dataset_id,
               table_id,
               timestamp_millis,
               num_partitions,
               features=None,
               columns=None,
               test_end_point=None,
               name=None):
    """Creates a BigQueryReader.

    Args:
      project_id: GCP project ID.
      dataset_id: BigQuery dataset ID.
      table_id: BigQuery table ID.
      timestamp_millis: timestamp to snapshot the table in milliseconds since
        the epoch. Relative (negative or zero) snapshot times are not allowed.
        For more details, see 'Table Decorators' in BigQuery docs.
      num_partitions: Number of non-overlapping partitions to read from.
      features: parse_example compatible dict from keys to `VarLenFeature` and
        `FixedLenFeature` objects. Keys are read as columns from the db.
      columns: list of columns to read, can be set iff features is None.
      test_end_point: Used only for testing purposes (optional).
      name: a name for the operation (optional).

    Raises:
      TypeError: - If features is neither None nor a dict or
                 - If columns is is neither None nor a list or
                 - If both features and columns are None or set.
    """
    if (features is None) == (columns is None):
      raise TypeError("exactly one of features and columns must be set.")

    if features is not None:
      if not isinstance(features, dict):
        raise TypeError("features must be a dict.")
      self._columns = list(features.keys())
    elif columns is not None:
      if not isinstance(columns, list):
        raise TypeError("columns must be a list.")
      self._columns = columns

    self._project_id = project_id
    self._dataset_id = dataset_id
    self._table_id = table_id
    self._timestamp_millis = timestamp_millis
    self._num_partitions = num_partitions
    self._test_end_point = test_end_point

    reader = gen_cloud_ops.big_query_reader(
        name=name,
        project_id=self._project_id,
        dataset_id=self._dataset_id,
        table_id=self._table_id,
        timestamp_millis=self._timestamp_millis,
        columns=self._columns,
        test_end_point=self._test_end_point)
    super(BigQueryReader, self).__init__(reader)

  def partitions(self, name=None):
    """Returns serialized BigQueryTablePartition messages.

    These messages represent a non-overlapping division of a table for a
    bulk read.

    Args:
      name: a name for the operation (optional).

    Returns:
      `1-D` string `Tensor` of serialized `BigQueryTablePartition` messages.
    """
    return gen_cloud_ops.generate_big_query_reader_partitions(
        name=name,
        project_id=self._project_id,
        dataset_id=self._dataset_id,
        table_id=self._table_id,
        timestamp_millis=self._timestamp_millis,
        num_partitions=self._num_partitions,
        test_end_point=self._test_end_point,
        columns=self._columns)


ops.NotDifferentiable("BigQueryReader")

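As a side note on the `BigQueryReader.__init__` check above: exactly one of `features` and `columns` may be supplied, and when `features` is given its keys become the column list. A standalone Python restatement of that argument handling (illustrative only, not part of the change; the helper name is made up):

```python
# Illustrative restatement of the features/columns handling in
# BigQueryReader.__init__ above.
def _resolve_columns(features=None, columns=None):
  """Returns the list of columns to read, mirroring the checks above."""
  if (features is None) == (columns is None):
    raise TypeError("exactly one of features and columns must be set.")
  if features is not None:
    if not isinstance(features, dict):
      raise TypeError("features must be a dict.")
    return list(features.keys())
  if not isinstance(columns, list):
    raise TypeError("columns must be a list.")
  return columns

print(_resolve_columns(features={"name": None, "age": None}))
print(_resolve_columns(columns=["int64_col", "string_col"]))
```
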
@@ -1,274 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BigQueryReader Op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import re
import threading

from six.moves import SimpleHTTPServer
from six.moves import socketserver
import tensorflow as tf

from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat

_PROJECT = "test-project"
_DATASET = "test-dataset"
_TABLE = "test-table"
# List representation of the test rows in the 'test-table' in BigQuery.
# The schema for each row is: [int64, string, float].
# The values for rows are generated such that some columns have null values. The
# general formula here is:
#   - The int64 column is present in every row.
#   - The string column is only avaiable in even rows.
#   - The float column is only available in every third row.
_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1],
         [4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None],
         [8, "s_8", None], [9, None, 9.1]]
# Schema for 'test-table'.
# The schema currently has three columns: int64, string, and float
_SCHEMA = {
    "kind": "bigquery#table",
    "id": "test-project:test-dataset.test-table",
    "schema": {
        "fields": [{
            "name": "int64_col",
            "type": "INTEGER",
            "mode": "NULLABLE"
        }, {
            "name": "string_col",
            "type": "STRING",
            "mode": "NULLABLE"
        }, {
            "name": "float_col",
            "type": "FLOAT",
            "mode": "NULLABLE"
        }]
    }
}


def _ConvertRowToExampleProto(row):
  """Converts the input row to an Example proto.

  Args:
    row: Input Row instance.

  Returns:
    An Example proto initialized with row values.
  """

  example = example_pb2.Example()
  example.features.feature["int64_col"].int64_list.value.append(row[0])
  if row[1] is not None:
    example.features.feature["string_col"].bytes_list.value.append(compat.as_bytes(row[1]))
  if row[2] is not None:
    example.features.feature["float_col"].float_list.value.append(row[2])
  return example


class FakeBigQueryServer(threading.Thread):
  """Fake http server to return schema and data for sample table."""

  def __init__(self, address, port):
    """Creates a FakeBigQueryServer.

    Args:
      address: Server address
      port: Server port. Pass 0 to automatically pick an empty port.
    """
    threading.Thread.__init__(self)
    self.handler = BigQueryRequestHandler
    self.httpd = socketserver.TCPServer((address, port), self.handler)

  def run(self):
    self.httpd.serve_forever()

  def shutdown(self):
    self.httpd.shutdown()
    self.httpd.socket.close()


class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
  """Responds to BigQuery HTTP requests.

  Attributes:
    num_rows: num_rows in the underlying table served by this class.
  """

  num_rows = 0

  def do_GET(self):
    if "data?maxResults=" not in self.path:
      # This is a schema request.
      _SCHEMA["numRows"] = self.num_rows
      response = json.dumps(_SCHEMA)
    else:
      # This is a data request.
      #
      # Extract max results and start index.
      max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0])
      start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0])

      # Send the rows as JSON.
      rows = []
      for row in _ROWS[start_index:start_index + max_results]:
        row_json = {
            "f": [{
                "v": str(row[0])
            }, {
                "v": str(row[1]) if row[1] is not None else None
            }, {
                "v": str(row[2]) if row[2] is not None else None
            }]
        }
        rows.append(row_json)
      response = json.dumps({
          "kind": "bigquery#table",
          "id": "test-project:test-dataset.test-table",
          "rows": rows
      })
    self.send_response(200)
    self.end_headers()
    self.wfile.write(compat.as_bytes(response))


def _SetUpQueue(reader):
  """Sets up a queue for a reader."""
  queue = tf.FIFOQueue(8, [types_pb2.DT_STRING], shapes=())
  key, value = reader.read(queue)
  queue.enqueue_many(reader.partitions()).run()
  queue.close().run()
  return key, value


class BigQueryReaderOpsTest(tf.test.TestCase):

  def setUp(self):
    super(BigQueryReaderOpsTest, self).setUp()
    self.server = FakeBigQueryServer("127.0.0.1", 0)
    self.server.start()
    logging.info("server address is %s:%s", self.server.httpd.server_address[0],
                 self.server.httpd.server_address[1])

  def tearDown(self):
    self.server.shutdown()
    super(BigQueryReaderOpsTest, self).tearDown()

  def _ReadAndCheckRowsUsingFeatures(self, num_rows):
    self.server.handler.num_rows = num_rows

    with self.test_session() as sess:
      feature_configs = {
          "int64_col":
              tf.FixedLenFeature(
                  [1], dtype=tf.int64),
          "string_col":
              tf.FixedLenFeature(
                  [1], dtype=tf.string, default_value="s_default"),
      }
      reader = tf.cloud.BigQueryReader(
          project_id=_PROJECT,
          dataset_id=_DATASET,
          table_id=_TABLE,
          num_partitions=4,
          features=feature_configs,
          timestamp_millis=1,
          test_end_point=("%s:%s" % (self.server.httpd.server_address[0],
                                     self.server.httpd.server_address[1])))

      key, value = _SetUpQueue(reader)

      seen_rows = []
      features = tf.parse_example(tf.reshape(value, [1]), feature_configs)
      for _ in range(num_rows):
        int_value, str_value = sess.run(
            [features["int64_col"], features["string_col"]])

        # Parse values returned from the session.
        self.assertEqual(int_value.shape, (1, 1))
        self.assertEqual(str_value.shape, (1, 1))
        int64_col = int_value[0][0]
        string_col = str_value[0][0]
        seen_rows.append(int64_col)

        # Compare.
        expected_row = _ROWS[int64_col]
        self.assertEqual(int64_col, expected_row[0])
        self.assertEqual(compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1]
                         else "s_default")

      self.assertItemsEqual(seen_rows, range(num_rows))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])

  def testReadingSingleRowUsingFeatures(self):
    self._ReadAndCheckRowsUsingFeatures(1)

  def testReadingMultipleRowsUsingFeatures(self):
    self._ReadAndCheckRowsUsingFeatures(10)

  def testReadingMultipleRowsUsingColumns(self):
    num_rows = 10
    self.server.handler.num_rows = num_rows

    with self.test_session() as sess:
      reader = tf.cloud.BigQueryReader(
          project_id=_PROJECT,
          dataset_id=_DATASET,
          table_id=_TABLE,
          num_partitions=4,
          columns=["int64_col", "float_col", "string_col"],
          timestamp_millis=1,
          test_end_point=("%s:%s" % (self.server.httpd.server_address[0],
                                     self.server.httpd.server_address[1])))
      key, value = _SetUpQueue(reader)
      seen_rows = []
      for row_index in range(num_rows):
        returned_row_id, example_proto = sess.run([key, value])
        example = example_pb2.Example()
        example.ParseFromString(example_proto)
        self.assertIn("int64_col", example.features.feature)
        feature = example.features.feature["int64_col"]
        self.assertEqual(len(feature.int64_list.value), 1)
        int64_col = feature.int64_list.value[0]
        seen_rows.append(int64_col)

        # Create our expected Example.
        expected_example = example_pb2.Example()
        expected_example = _ConvertRowToExampleProto(_ROWS[int64_col])

        # Compare.
        self.assertProtoEquals(example, expected_example)
        self.assertEqual(row_index, int(returned_row_id))

      self.assertItemsEqual(seen_rows, range(num_rows))

      with self.assertRaisesOpError("is closed and has insufficient elements "
                                    "\\(requested 1, current size 0\\)"):
        sess.run([key, value])


if __name__ == "__main__":
  tf.test.main()