Automated rollback of change 142681535

Change: 142694447
A. Unique TensorFlower 2016-12-21 13:19:55 -08:00 committed by TensorFlower Gardener
parent e884a9cadb
commit e959223fd3
12 changed files with 52 additions and 900 deletions

tensorflow/core/BUILD

@@ -450,23 +450,12 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "cloud_ops_op_lib",
srcs = ["ops/cloud_ops.cc"],
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:public"],
deps = [":framework"],
alwayslink = 1,
)
cc_library(
name = "ops",
visibility = ["//visibility:public"],
deps = [
":array_ops_op_lib",
":candidate_sampling_ops_op_lib",
":cloud_ops_op_lib",
":control_flow_ops_op_lib",
":ctc_ops_op_lib",
":data_flow_ops_op_lib",
@@ -613,7 +602,6 @@ cc_library(
"//tensorflow/core/kernels:string",
"//tensorflow/core/kernels:training_ops",
"//tensorflow/core/kernels:word2vec_kernels",
"//tensorflow/core/kernels/cloud:bigquery_reader_ops",
] + if_not_windows([
"//tensorflow/core/kernels:fact_op",
"//tensorflow/core/kernels:array_not_windows",

tensorflow/core/kernels/cloud/BUILD

@@ -9,7 +9,6 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_kernel_library",
"tf_cc_test",
)
@@ -31,24 +30,6 @@ filegroup(
visibility = ["//tensorflow:__subpackages__"],
)
tf_kernel_library(
name = "bigquery_reader_ops",
srcs = [
"bigquery_reader_ops.cc",
],
visibility = ["//visibility:public"],
deps = [
":bigquery_table_accessor",
":bigquery_table_partition_proto_cc",
"//tensorflow/core:cloud_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:reader_base",
],
)
cc_library(
name = "bigquery_table_accessor",
srcs = [

tensorflow/core/kernels/cloud/bigquery_reader_ops.cc

@@ -1,193 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <map>
#include <memory>
#include <set>
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/framework/reader_op_kernel.h"
#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h"
#include "tensorflow/core/kernels/cloud/bigquery_table_partition.pb.h"
#include "tensorflow/core/kernels/reader_base.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
namespace tensorflow {
namespace {
constexpr int64 kDefaultRowBufferSize = 1000; // Number of rows to buffer.
// This is a helper function for reading table attributes from context.
Status GetTableAttrs(OpKernelConstruction* context, string* project_id,
string* dataset_id, string* table_id,
int64* timestamp_millis, std::vector<string>* columns,
string* test_end_point) {
TF_RETURN_IF_ERROR(context->GetAttr("project_id", project_id));
TF_RETURN_IF_ERROR(context->GetAttr("dataset_id", dataset_id));
TF_RETURN_IF_ERROR(context->GetAttr("table_id", table_id));
TF_RETURN_IF_ERROR(context->GetAttr("timestamp_millis", timestamp_millis));
TF_RETURN_IF_ERROR(context->GetAttr("columns", columns));
TF_RETURN_IF_ERROR(context->GetAttr("test_end_point", test_end_point));
return Status::OK();
}
} // namespace
// Note that overridden methods with names ending in "Locked" are called by
// ReaderBase while a mutex is held.
// See comments for ReaderBase.
class BigQueryReader : public ReaderBase {
public:
explicit BigQueryReader(BigQueryTableAccessor* bigquery_table_accessor,
const string& node_name)
: ReaderBase(strings::StrCat("BigQueryReader '", node_name, "'")),
bigquery_table_accessor_(CHECK_NOTNULL(bigquery_table_accessor)) {}
Status OnWorkStartedLocked() override {
BigQueryTablePartition partition;
if (!partition.ParseFromString(current_work())) {
return errors::InvalidArgument(
"Could not parse work as as valid partition.");
}
TF_RETURN_IF_ERROR(bigquery_table_accessor_->SetPartition(partition));
return Status::OK();
}
Status ReadLocked(string* key, string* value, bool* produced,
bool* at_end) override {
*at_end = false;
*produced = false;
if (bigquery_table_accessor_->Done()) {
*at_end = true;
return Status::OK();
}
Example example;
int64 row_id;
TF_RETURN_IF_ERROR(bigquery_table_accessor_->ReadRow(&row_id, &example));
*key = std::to_string(row_id);
*value = example.SerializeAsString();
*produced = true;
return Status::OK();
}
private:
// Not owned.
BigQueryTableAccessor* bigquery_table_accessor_;
};
class BigQueryReaderOp : public ReaderOpKernel {
public:
explicit BigQueryReaderOp(OpKernelConstruction* context)
: ReaderOpKernel(context) {
string table_id;
string project_id;
string dataset_id;
int64 timestamp_millis;
std::vector<string> columns;
string test_end_point;
OP_REQUIRES_OK(context,
GetTableAttrs(context, &project_id, &dataset_id, &table_id,
&timestamp_millis, &columns, &test_end_point));
OP_REQUIRES_OK(context,
BigQueryTableAccessor::New(
project_id, dataset_id, table_id, timestamp_millis,
kDefaultRowBufferSize, test_end_point, columns,
BigQueryTablePartition(), &bigquery_table_accessor_));
SetReaderFactory([this]() {
return new BigQueryReader(bigquery_table_accessor_.get(), name());
});
}
private:
std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
};
REGISTER_KERNEL_BUILDER(Name("BigQueryReader").Device(DEVICE_CPU),
BigQueryReaderOp);
class GenerateBigQueryReaderPartitionsOp : public OpKernel {
public:
explicit GenerateBigQueryReaderPartitionsOp(OpKernelConstruction* context)
: OpKernel(context) {
string project_id;
string dataset_id;
string table_id;
int64 timestamp_millis;
std::vector<string> columns;
string test_end_point;
OP_REQUIRES_OK(context,
GetTableAttrs(context, &project_id, &dataset_id, &table_id,
&timestamp_millis, &columns, &test_end_point));
OP_REQUIRES_OK(context,
BigQueryTableAccessor::New(
project_id, dataset_id, table_id, timestamp_millis,
kDefaultRowBufferSize, test_end_point, columns,
BigQueryTablePartition(), &bigquery_table_accessor_));
OP_REQUIRES_OK(context, InitializeNumberOfPartitions(context));
OP_REQUIRES_OK(context, InitializeTotalNumberOfRows());
}
void Compute(OpKernelContext* context) override {
const int64 partition_size = tensorflow::MathUtil::CeilOfRatio<int64>(
total_num_rows_, num_partitions_);
Tensor* output_tensor = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({num_partitions_}),
&output_tensor));
auto output = output_tensor->template flat<string>();
for (int64 i = 0; i < num_partitions_; ++i) {
BigQueryTablePartition partition;
partition.set_start_index(i * partition_size);
partition.set_end_index(
std::min(total_num_rows_, (i + 1) * partition_size) - 1);
output(i) = partition.SerializeAsString();
}
}
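// Worked example of the partitioning arithmetic above (illustrative
// numbers, not taken from this change): with total_num_rows_ = 10 and
// num_partitions_ = 4, partition_size = CeilOfRatio(10, 4) = 3, so the
// serialized partitions cover the inclusive row ranges [0, 2], [3, 5],
// [6, 8], and [9, 9]; the std::min(total_num_rows_, ...) - 1 term clamps
// the final end index.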
private:
Status InitializeTotalNumberOfRows() {
total_num_rows_ = bigquery_table_accessor_->total_num_rows();
if (total_num_rows_ <= 0) {
return errors::FailedPrecondition("Invalid total number of rows.");
}
return Status::OK();
}
Status InitializeNumberOfPartitions(OpKernelConstruction* context) {
TF_RETURN_IF_ERROR(context->GetAttr("num_partitions", &num_partitions_));
if (num_partitions_ <= 0) {
return errors::FailedPrecondition("Invalid number of partitions.");
}
return Status::OK();
}
int64 num_partitions_;
int64 total_num_rows_;
std::unique_ptr<BigQueryTableAccessor> bigquery_table_accessor_;
};
REGISTER_KERNEL_BUILDER(
Name("GenerateBigQueryReaderPartitions").Device(DEVICE_CPU),
GenerateBigQueryReaderPartitionsOp);
} // namespace tensorflow

tensorflow/core/kernels/cloud/bigquery_table_accessor.cc

@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/cloud/bigquery_table_accessor.h"
#include "tensorflow/core/example/feature.pb.h"
@@ -22,15 +23,6 @@ namespace tensorflow {
namespace {
constexpr size_t kBufferSize = 1024 * 1024; // In bytes.
const string kBigQueryEndPoint = "https://www.googleapis.com/bigquery/v2";
bool IsPartitionEmpty(const BigQueryTablePartition& partition) {
if (partition.end_index() != -1 &&
partition.end_index() < partition.start_index()) {
return true;
}
return false;
}
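// For example, a partition with start_index = 3 and end_index = 2 is
// treated as empty (the exact shape used by the EmptyPartitionTest that
// this rollback deletes from the test file below), whereas end_index = -1
// denotes an unbounded partition and is never empty.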
Status ParseJson(StringPiece json, Json::Value* result) {
Json::Reader reader;
@@ -100,18 +92,17 @@ Status ParseColumnType(const string& type,
Status BigQueryTableAccessor::New(
const string& project_id, const string& dataset_id, const string& table_id,
int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
const std::vector<string>& columns, const BigQueryTablePartition& partition,
int64 timestamp_millis, int64 row_buffer_size,
const std::set<string>& columns, const BigQueryTablePartition& partition,
std::unique_ptr<BigQueryTableAccessor>* accessor) {
return New(project_id, dataset_id, table_id, timestamp_millis,
row_buffer_size, end_point, columns, partition, nullptr, nullptr,
accessor);
row_buffer_size, columns, partition, nullptr, nullptr, accessor);
}
Status BigQueryTableAccessor::New(
const string& project_id, const string& dataset_id, const string& table_id,
int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
const std::vector<string>& columns, const BigQueryTablePartition& partition,
int64 timestamp_millis, int64 row_buffer_size,
const std::set<string>& columns, const BigQueryTablePartition& partition,
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory,
std::unique_ptr<BigQueryTableAccessor>* accessor) {
@@ -119,16 +110,14 @@ Status BigQueryTableAccessor::New(
return errors::InvalidArgument(
"Cannot use zero or negative timestamp to query a table.");
}
const string& big_query_end_point =
end_point.empty() ? kBigQueryEndPoint : end_point;
if (auth_provider == nullptr && http_request_factory == nullptr) {
accessor->reset(new BigQueryTableAccessor(
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
big_query_end_point, columns, partition));
accessor->reset(new BigQueryTableAccessor(project_id, dataset_id, table_id,
timestamp_millis, row_buffer_size,
columns, partition));
} else {
accessor->reset(new BigQueryTableAccessor(
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
big_query_end_point, columns, partition, std::move(auth_provider),
columns, partition, std::move(auth_provider),
std::move(http_request_factory)));
}
return (*accessor)->ReadSchema();
@@ -136,11 +125,11 @@ Status BigQueryTableAccessor::New(
BigQueryTableAccessor::BigQueryTableAccessor(
const string& project_id, const string& dataset_id, const string& table_id,
int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
const std::vector<string>& columns, const BigQueryTablePartition& partition)
int64 timestamp_millis, int64 row_buffer_size,
const std::set<string>& columns, const BigQueryTablePartition& partition)
: BigQueryTableAccessor(
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
end_point, columns, partition,
columns, partition,
std::unique_ptr<AuthProvider>(new GoogleAuthProvider()),
std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory())) {
row_buffer_.resize(row_buffer_size);
@@ -148,16 +137,15 @@ BigQueryTableAccessor::BigQueryTableAccessor(
BigQueryTableAccessor::BigQueryTableAccessor(
const string& project_id, const string& dataset_id, const string& table_id,
int64 timestamp_millis, int64 row_buffer_size, const string& end_point,
const std::vector<string>& columns, const BigQueryTablePartition& partition,
int64 timestamp_millis, int64 row_buffer_size,
const std::set<string>& columns, const BigQueryTablePartition& partition,
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory)
: project_id_(project_id),
dataset_id_(dataset_id),
table_id_(table_id),
timestamp_millis_(timestamp_millis),
columns_(columns.begin(), columns.end()),
bigquery_end_point_(end_point),
columns_(columns),
partition_(partition),
auth_provider_(std::move(auth_provider)),
http_request_factory_(std::move(http_request_factory)) {
@@ -165,14 +153,10 @@ BigQueryTableAccessor::BigQueryTableAccessor(
Reset();
}
Status BigQueryTableAccessor::SetPartition(
void BigQueryTableAccessor::SetPartition(
const BigQueryTablePartition& partition) {
if (partition.start_index() < 0) {
return errors::InvalidArgument("Start index cannot be negative.");
}
partition_ = partition;
Reset();
return Status::OK();
}
void BigQueryTableAccessor::Reset() {
@@ -188,8 +172,7 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
// If the next row is already fetched and cached, return the row from the
// buffer. Otherwise, fill up the row buffer from BigQuery and return a row.
if (next_row_in_buffer_ != -1 &&
next_row_in_buffer_ < ComputeMaxResultsArg()) {
if (next_row_in_buffer_ != -1 && next_row_in_buffer_ < row_buffer_.size()) {
*row_id = first_buffered_row_index_ + next_row_in_buffer_;
*example = row_buffer_[next_row_in_buffer_];
next_row_in_buffer_++;
@@ -207,12 +190,12 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
// we use the page token (which returns rows faster).
if (!next_page_token_.empty()) {
TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(),
"&pageToken=", request->EscapeString(next_page_token_))));
first_buffered_row_index_ += row_buffer_.size();
} else {
TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(),
BigQueryUriPrefix(), "data?maxResults=", row_buffer_.size(),
"&startIndex=", first_buffered_row_index_)));
}
TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
@@ -239,18 +222,6 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) {
return Status::OK();
}
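// Sketch of the URIs the two branches above produce, with hypothetical
// values: the first fetch of a partition issues
//   .../data?maxResults=1000&startIndex=0
// and once BigQuery has returned a page token, later fetches issue
//   .../data?maxResults=1000&pageToken=<token>
// instead, which returns rows faster than an index-based read.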
int64 BigQueryTableAccessor::ComputeMaxResultsArg() {
if (partition_.end_index() == -1) {
return row_buffer_.size();
}
if (IsPartitionEmpty(partition_)) {
return 0;
}
return std::min(static_cast<int64>(row_buffer_.size()),
static_cast<int64>(partition_.end_index() -
partition_.start_index() + 1));
}
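// Example for the deleted helper above (illustrative numbers): with a
// 1000-row buffer and a bounded partition [start_index = 2, end_index = 2],
// ComputeMaxResultsArg() returned min(1000, 2 - 2 + 1) = 1, i.e. only the
// rows remaining in the partition; with end_index == -1 it fell back to the
// full buffer size, which is what the rolled-back code now always uses.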
Status BigQueryTableAccessor::ParseColumnValues(
const Json::Value& value, const SchemaNode& root_schema_node,
Example* example) {
@@ -393,17 +364,21 @@ Status BigQueryTableAccessor::AppendValueToExample(
string BigQueryTableAccessor::BigQueryUriPrefix() {
HttpRequest request;
return strings::StrCat(bigquery_end_point_, "/projects/",
return strings::StrCat("https://www.googleapis.com/bigquery/v2/projects/",
request.EscapeString(project_id_), "/datasets/",
request.EscapeString(dataset_id_), "/tables/",
request.EscapeString(table_id_), "/");
}
string BigQueryTableAccessor::FullTableName() {
return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
timestamp_millis_);
}
bool BigQueryTableAccessor::Done() {
return (total_num_rows_ <= first_buffered_row_index_ + next_row_in_buffer_) ||
IsPartitionEmpty(partition_) ||
(partition_.end_index() != -1 &&
partition_.end_index() <
partition_.end_index() <=
first_buffered_row_index_ + next_row_in_buffer_);
}
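// Note the shift from "<" to "<=" above: after this rollback the
// partition's end_index behaves as an exclusive bound. For example, a
// partition with start_index = 2 and end_index = 3 reads exactly row 2,
// which is why the tests below move each end index up by one
// (set_end_index(2) becomes set_end_index(3)).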

tensorflow/core/kernels/cloud/bigquery_table_accessor.h

@@ -55,23 +55,16 @@ class BigQueryTableAccessor {
};
/// \brief Creates a new BigQueryTableAccessor object.
//
// We do not allow relative (negative or zero) snapshot times here since we
// want to have a consistent snapshot of the table for the lifetime of this
// object.
// Use end_point if you want to connect to a different end point than the
// official BigQuery end point. Otherwise send an empty string.
static Status New(const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis,
int64 row_buffer_size, const string& end_point,
const std::vector<string>& columns,
int64 row_buffer_size, const std::set<string>& columns,
const BigQueryTablePartition& partition,
std::unique_ptr<BigQueryTableAccessor>* accessor);
/// \brief Starts reading a new partition.
Status SetPartition(const BigQueryTablePartition& partition);
void SetPartition(const BigQueryTablePartition& partition);
/// \brief Returns true if there are more rows available in the current
/// \brief Returns false if there are more rows available in the current
/// partition.
bool Done();
@@ -81,11 +74,9 @@ class BigQueryTableAccessor {
/// in the BigQuery service.
Status ReadRow(int64* row_id, Example* example);
/// \brief Returns total number of rows in the table.
/// \brief Returns total number of rows.
int64 total_num_rows() { return total_num_rows_; }
virtual ~BigQueryTableAccessor() {}
private:
friend class BigQueryTableAccessorTest;
@@ -104,8 +95,7 @@ class BigQueryTableAccessor {
/// these two variables.
static Status New(const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis,
int64 row_buffer_size, const string& end_point,
const std::vector<string>& columns,
int64 row_buffer_size, const std::set<string>& columns,
const BigQueryTablePartition& partition,
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory,
@@ -114,16 +104,14 @@ class BigQueryTableAccessor {
/// \brief Constructs an object for a given table and partition.
BigQueryTableAccessor(const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis,
int64 row_buffer_size, const string& end_point,
const std::vector<string>& columns,
int64 row_buffer_size, const std::set<string>& columns,
const BigQueryTablePartition& partition);
/// Used for unit testing.
BigQueryTableAccessor(
const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis, int64 row_buffer_size,
const string& end_point, const std::vector<string>& columns,
const BigQueryTablePartition& partition,
const std::set<string>& columns, const BigQueryTablePartition& partition,
std::unique_ptr<AuthProvider> auth_provider,
std::unique_ptr<HttpRequest::Factory> http_request_factory);
@@ -144,7 +132,7 @@ class BigQueryTableAccessor {
Status AppendValueToExample(const string& column_name,
const Json::Value& column_value,
const BigQueryTableAccessor::ColumnType type,
Example* example);
Example* ex);
/// \brief Resets internal counters for reading a partition.
void Reset();
@@ -152,28 +140,25 @@ class BigQueryTableAccessor {
/// \brief Helper function that returns BigQuery http endpoint prefix.
string BigQueryUriPrefix();
/// \brief Computes the maxResults arg to send to BigQuery.
int64 ComputeMaxResultsArg();
/// \brief Returns full name of the underlying table name.
string FullTableName() {
return strings::StrCat(project_id_, ":", dataset_id_, ".", table_id_, "@",
timestamp_millis_);
}
string FullTableName();
const string project_id_;
const string dataset_id_;
const string table_id_;
// Snapshot timestamp.
//
// Indicates a snapshot of the table in milliseconds since the epoch.
//
// We do not allow relative (negative or zero) times here since we want to
// have a consistent snapshot of the table for the lifetime of this object.
// For more details, see 'Table Decorators' in BigQuery documentation.
const int64 timestamp_millis_;
// Columns that should be read. Empty means all columns.
const std::set<string> columns_;
// HTTP address of BigQuery end point to use.
const string bigquery_end_point_;
// Describes the portion of the table that we are currently accessing.
BigQueryTablePartition partition_;

tensorflow/core/kernels/cloud/bigquery_table_accessor_test.cc

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
constexpr char kTestProject[] = "test-project";
@@ -68,10 +69,10 @@ class BigQueryTableAccessorTest : public ::testing::Test {
Status CreateTableAccessor(const string& project_id, const string& dataset_id,
const string& table_id, int64 timestamp_millis,
int64 row_buffer_size,
const std::vector<string>& columns,
const std::set<string>& columns,
const BigQueryTablePartition& partition) {
return BigQueryTableAccessor::New(
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, "",
project_id, dataset_id, table_id, timestamp_millis, row_buffer_size,
columns, partition, std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests_)),
@@ -196,7 +197,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowTest) {
kTestRow));
BigQueryTablePartition partition;
partition.set_start_index(2);
partition.set_end_index(2);
partition.set_end_index(3);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
{}, partition));
int64 row_id;
@@ -226,7 +227,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowPartialTest) {
kTestRow));
BigQueryTablePartition partition;
partition.set_start_index(2);
partition.set_end_index(2);
partition.set_end_index(3);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
{"bool_field", "rec_field.float_field"},
partition));
@@ -257,7 +258,7 @@ TEST_F(BigQueryTableAccessorTest, ReadOneRowWithNullsTest) {
kTestRowWithNulls));
BigQueryTablePartition partition;
partition.set_start_index(2);
partition.set_end_index(2);
partition.set_end_index(3);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
{}, partition));
int64 row_id;
@@ -287,7 +288,7 @@ TEST_F(BigQueryTableAccessorTest, BrokenRowTest) {
kBrokenTestRow));
BigQueryTablePartition partition;
partition.set_start_index(2);
partition.set_end_index(2);
partition.set_end_index(3);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
{}, partition));
int64 row_id;
@@ -356,7 +357,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
kSampleSchema));
requests_.emplace_back(new FakeHttpRequest(
"Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
"datasets/test-dataset/tables/test-table/data?maxResults=1&startIndex=0\n"
"datasets/test-dataset/tables/test-table/data?maxResults=2&startIndex=0\n"
"Auth Token: fake_token\n",
kTestTwoRows));
requests_.emplace_back(new FakeHttpRequest(
@@ -373,7 +374,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
BigQueryTablePartition partition;
partition.set_start_index(0);
partition.set_end_index(0);
partition.set_end_index(1);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 2,
{}, partition));
@@ -395,7 +396,7 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
1234);
partition.set_start_index(0);
partition.set_end_index(1);
partition.set_end_index(2);
accessor_->SetPartition(partition);
TF_EXPECT_OK(accessor_->ReadRow(&row_id, &example));
EXPECT_EQ(0, row_id);
@@ -409,23 +410,4 @@ TEST_F(BigQueryTableAccessorTest, SwitchingPartitionsTest) {
2222);
}
TEST_F(BigQueryTableAccessorTest, EmptyPartitionTest) {
requests_.emplace_back(new FakeHttpRequest(
"Uri: https://www.googleapis.com/bigquery/v2/projects/test-project/"
"datasets/test-dataset/tables/test-table/\n"
"Auth Token: fake_token\n",
kSampleSchema));
BigQueryTablePartition partition;
partition.set_start_index(3);
partition.set_end_index(2);
TF_EXPECT_OK(CreateTableAccessor(kTestProject, kTestDataset, kTestTable, 1, 1,
{}, partition));
EXPECT_TRUE(accessor_->Done());
int64 row_id;
Example example;
EXPECT_TRUE(errors::IsOutOfRange(accessor_->ReadRow(&row_id, &example)));
}
} // namespace tensorflow

tensorflow/core/ops/cloud_ops.cc

@@ -1,88 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/* This file registers all cloud ops. */
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
using shape_inference::InferenceContext;
REGISTER_OP("BigQueryReader")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
.Attr("project_id: string")
.Attr("dataset_id: string")
.Attr("table_id: string")
.Attr("columns: list(string)")
.Attr("timestamp_millis: int")
.Attr("test_end_point: string = ''")
.Output("reader_handle: Ref(string)")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
return Status::OK();
})
.Doc(R"doc(
A Reader that outputs rows from a BigQuery table as tensorflow Examples.
container: If non-empty, this reader is placed in the given container.
Otherwise, a default container is used.
shared_name: If non-empty, this reader is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
project_id: GCP project ID.
dataset_id: BigQuery Dataset ID.
table_id: Table to read.
columns: List of columns to read. Leave empty to read all columns.
timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
(negative or zero) snapshot times are not allowed. For more details, see
'Table Decorators' in BigQuery docs.
test_end_point: Do not use. For testing purposes only.
reader_handle: The handle to reference the Reader.
)doc");
REGISTER_OP("GenerateBigQueryReaderPartitions")
.Attr("project_id: string")
.Attr("dataset_id: string")
.Attr("table_id: string")
.Attr("columns: list(string)")
.Attr("timestamp_millis: int")
.Attr("num_partitions: int")
.Attr("test_end_point: string = ''")
.Output("partitions: string")
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
return Status::OK();
})
.Doc(R"doc(
Generates serialized partition messages suitable for batch reads.
This op should not be used directly by clients. Instead, the
bigquery_reader_ops.py file defines a clean interface to the reader.
project_id: GCP project ID.
dataset_id: BigQuery Dataset ID.
table_id: Table to read.
columns: List of columns to read. Leave empty to read all columns.
timestamp_millis: Table snapshot timestamp in millis since epoch. Relative
(negative or zero) snapshot times are not allowed. For more details, see
'Table Decorators' in BigQuery docs.
num_partitions: Number of partitions to split the table into.
test_end_point: Do not use. For testing purposes only.
partitions: Serialized table partitions.
)doc");
} // namespace tensorflow

tensorflow/python/BUILD

@@ -35,7 +35,6 @@ py_library(
":check_ops",
":client",
":client_testlib",
":cloud_ops",
":confusion_matrix",
":control_flow_ops",
":errors",
@@ -790,11 +789,6 @@ tf_gen_op_wrapper_private_py(
require_shape_functions = True,
)
tf_gen_op_wrapper_private_py(
name = "cloud_ops_gen",
require_shape_functions = True,
)
tf_gen_op_wrapper_private_py(
name = "control_flow_ops_gen",
require_shape_functions = True,
@@ -1437,19 +1431,6 @@ py_library(
],
)
py_library(
name = "cloud_ops",
srcs = [
"ops/cloud/__init__.py",
"ops/cloud/bigquery_reader_ops.py",
],
srcs_version = "PY2AND3",
deps = [
":cloud_ops_gen",
":framework",
],
)
py_library(
name = "script_ops",
srcs = ["ops/script_ops.py"],
@@ -2050,17 +2031,6 @@ cuda_py_test(
],
)
tf_py_test(
name = "bigquery_reader_ops_test",
size = "small",
srcs = ["ops/cloud/bigquery_reader_ops_test.py"],
additional_deps = [
":cloud_ops",
"//tensorflow:tensorflow_py",
"//tensorflow/python:util",
],
)
py_library(
name = "training",
srcs = glob(

tensorflow/python/__init__.py

@@ -83,7 +83,6 @@ from tensorflow.python.ops.standard_ops import *
# Bring in subpackages.
from tensorflow.python.layers import layers
from tensorflow.python.ops import cloud
from tensorflow.python.ops import metrics
from tensorflow.python.ops import nn
from tensorflow.python.ops import sdca_ops as sdca
@@ -214,7 +213,6 @@ _allowed_symbols.extend([
_allowed_symbols.extend([
'app',
'compat',
'cloud',
'errors',
'flags',
'gfile',

tensorflow/python/ops/cloud/__init__.py

@@ -1,22 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import cloud ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.python.ops.cloud.bigquery_reader_ops import *

tensorflow/python/ops/cloud/bigquery_reader_ops.py

@@ -1,150 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""BigQuery reading support for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_cloud_ops
from tensorflow.python.ops import io_ops
class BigQueryReader(io_ops.ReaderBase):
"""A Reader that outputs keys and tf.Example values from a BigQuery table.
Example use:
```python
# Assume a BigQuery has the following schema,
# name STRING,
# age INT,
# state STRING
# Create the parse_examples list of features.
features = dict(
name=tf.FixedLenFeature([1], tf.string),
age=tf.FixedLenFeature([1], tf.int64),
state=tf.FixedLenFeature([1], dtype=tf.string, default_value="UNK"))
# Create a Reader.
reader = bigquery_reader_ops.BigQueryReader(project_id=PROJECT,
dataset_id=DATASET,
table_id=TABLE,
timestamp_millis=TIME,
num_partitions=NUM_PARTITIONS,
features=features)
# Populate a queue with the BigQuery Table partitions.
queue = tf.train.string_input_producer(reader.partitions())
# Read and parse examples.
row_id, examples_serialized = reader.read(queue)
examples = tf.parse_example(examples_serialized, features=features)
# Process the Tensors examples["name"], examples["age"], etc...
```
Note that a snapshot timestamp is necessary to create a reader. This
will enable the reader to look at a consistent snapshot of the table.
For more information, see 'Table Decorators' in BigQuery docs.
See ReaderBase for supported methods.
"""
def __init__(self,
project_id,
dataset_id,
table_id,
timestamp_millis,
num_partitions,
features=None,
columns=None,
test_end_point=None,
name=None):
"""Creates a BigQueryReader.
Args:
project_id: GCP project ID.
dataset_id: BigQuery dataset ID.
table_id: BigQuery table ID.
timestamp_millis: timestamp to snapshot the table in milliseconds since
the epoch. Relative (negative or zero) snapshot times are not allowed.
For more details, see 'Table Decorators' in BigQuery docs.
num_partitions: Number of non-overlapping partitions to read from.
features: parse_example compatible dict from keys to `VarLenFeature` and
`FixedLenFeature` objects. Keys are read as columns from the db.
columns: list of columns to read, can be set iff features is None.
test_end_point: Used only for testing purposes (optional).
name: a name for the operation (optional).
Raises:
TypeError: - If features is neither None nor a dict or
- If columns is neither None nor a list or
- If both features and columns are None or set.
"""
if (features is None) == (columns is None):
raise TypeError("exactly one of features and columns must be set.")
if features is not None:
if not isinstance(features, dict):
raise TypeError("features must be a dict.")
self._columns = list(features.keys())
elif columns is not None:
if not isinstance(columns, list):
raise TypeError("columns must be a list.")
self._columns = columns
self._project_id = project_id
self._dataset_id = dataset_id
self._table_id = table_id
self._timestamp_millis = timestamp_millis
self._num_partitions = num_partitions
self._test_end_point = test_end_point
reader = gen_cloud_ops.big_query_reader(
name=name,
project_id=self._project_id,
dataset_id=self._dataset_id,
table_id=self._table_id,
timestamp_millis=self._timestamp_millis,
columns=self._columns,
test_end_point=self._test_end_point)
super(BigQueryReader, self).__init__(reader)
def partitions(self, name=None):
"""Returns serialized BigQueryTablePartition messages.
These messages represent a non-overlapping division of a table for a
bulk read.
Args:
name: a name for the operation (optional).
Returns:
`1-D` string `Tensor` of serialized `BigQueryTablePartition` messages.
"""
return gen_cloud_ops.generate_big_query_reader_partitions(
name=name,
project_id=self._project_id,
dataset_id=self._dataset_id,
table_id=self._table_id,
timestamp_millis=self._timestamp_millis,
num_partitions=self._num_partitions,
test_end_point=self._test_end_point,
columns=self._columns)
ops.NotDifferentiable("BigQueryReader")

tensorflow/python/ops/cloud/bigquery_reader_ops_test.py

@@ -1,274 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for BigQueryReader Op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import re
import threading
from six.moves import SimpleHTTPServer
from six.moves import socketserver
import tensorflow as tf
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
_PROJECT = "test-project"
_DATASET = "test-dataset"
_TABLE = "test-table"
# List representation of the test rows in the 'test-table' in BigQuery.
# The schema for each row is: [int64, string, float].
# The values for rows are generated such that some columns have null values. The
# general formula here is:
# - The int64 column is present in every row.
# - The string column is only available in even rows.
# - The float column is only available in every third row.
_ROWS = [[0, "s_0", 0.1], [1, None, None], [2, "s_2", None], [3, None, 3.1],
[4, "s_4", None], [5, None, None], [6, "s_6", 6.1], [7, None, None],
[8, "s_8", None], [9, None, 9.1]]
# Schema for 'test-table'.
# The schema currently has three columns: int64, string, and float
_SCHEMA = {
"kind": "bigquery#table",
"id": "test-project:test-dataset.test-table",
"schema": {
"fields": [{
"name": "int64_col",
"type": "INTEGER",
"mode": "NULLABLE"
}, {
"name": "string_col",
"type": "STRING",
"mode": "NULLABLE"
}, {
"name": "float_col",
"type": "FLOAT",
"mode": "NULLABLE"
}]
}
}
def _ConvertRowToExampleProto(row):
"""Converts the input row to an Example proto.
Args:
row: Input Row instance.
Returns:
An Example proto initialized with row values.
"""
example = example_pb2.Example()
example.features.feature["int64_col"].int64_list.value.append(row[0])
if row[1] is not None:
example.features.feature["string_col"].bytes_list.value.append(compat.as_bytes(row[1]))
if row[2] is not None:
example.features.feature["float_col"].float_list.value.append(row[2])
return example
class FakeBigQueryServer(threading.Thread):
"""Fake http server to return schema and data for sample table."""
def __init__(self, address, port):
"""Creates a FakeBigQueryServer.
Args:
address: Server address
port: Server port. Pass 0 to automatically pick an unused port.
"""
threading.Thread.__init__(self)
self.handler = BigQueryRequestHandler
self.httpd = socketserver.TCPServer((address, port), self.handler)
def run(self):
self.httpd.serve_forever()
def shutdown(self):
self.httpd.shutdown()
self.httpd.socket.close()
class BigQueryRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
"""Responds to BigQuery HTTP requests.
Attributes:
num_rows: num_rows in the underlying table served by this class.
"""
num_rows = 0
def do_GET(self):
if "data?maxResults=" not in self.path:
# This is a schema request.
_SCHEMA["numRows"] = self.num_rows
response = json.dumps(_SCHEMA)
else:
# This is a data request.
#
# Extract max results and start index.
max_results = int(re.findall(r"maxResults=(\d+)", self.path)[0])
start_index = int(re.findall(r"startIndex=(\d+)", self.path)[0])
# Send the rows as JSON.
rows = []
for row in _ROWS[start_index:start_index + max_results]:
row_json = {
"f": [{
"v": str(row[0])
}, {
"v": str(row[1]) if row[1] is not None else None
}, {
"v": str(row[2]) if row[2] is not None else None
}]
}
rows.append(row_json)
response = json.dumps({
"kind": "bigquery#table",
"id": "test-project:test-dataset.test-table",
"rows": rows
})
self.send_response(200)
self.end_headers()
self.wfile.write(compat.as_bytes(response))
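# Illustrative encoding using the rows defined above: row [0, "s_0", 0.1]
# is served as {"f": [{"v": "0"}, {"v": "s_0"}, {"v": "0.1"}]}, and a null
# cell (e.g. the string column of row 1) becomes {"v": None}, mirroring the
# JSON layout of the BigQuery tabledata responses this fake server mimics.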
def _SetUpQueue(reader):
"""Sets up a queue for a reader."""
queue = tf.FIFOQueue(8, [types_pb2.DT_STRING], shapes=())
key, value = reader.read(queue)
queue.enqueue_many(reader.partitions()).run()
queue.close().run()
return key, value
class BigQueryReaderOpsTest(tf.test.TestCase):
def setUp(self):
super(BigQueryReaderOpsTest, self).setUp()
self.server = FakeBigQueryServer("127.0.0.1", 0)
self.server.start()
logging.info("server address is %s:%s", self.server.httpd.server_address[0],
self.server.httpd.server_address[1])
def tearDown(self):
self.server.shutdown()
super(BigQueryReaderOpsTest, self).tearDown()
def _ReadAndCheckRowsUsingFeatures(self, num_rows):
self.server.handler.num_rows = num_rows
with self.test_session() as sess:
feature_configs = {
"int64_col":
tf.FixedLenFeature(
[1], dtype=tf.int64),
"string_col":
tf.FixedLenFeature(
[1], dtype=tf.string, default_value="s_default"),
}
reader = tf.cloud.BigQueryReader(
project_id=_PROJECT,
dataset_id=_DATASET,
table_id=_TABLE,
num_partitions=4,
features=feature_configs,
timestamp_millis=1,
test_end_point=("%s:%s" % (self.server.httpd.server_address[0],
self.server.httpd.server_address[1])))
key, value = _SetUpQueue(reader)
seen_rows = []
features = tf.parse_example(tf.reshape(value, [1]), feature_configs)
for _ in range(num_rows):
int_value, str_value = sess.run(
[features["int64_col"], features["string_col"]])
# Parse values returned from the session.
self.assertEqual(int_value.shape, (1, 1))
self.assertEqual(str_value.shape, (1, 1))
int64_col = int_value[0][0]
string_col = str_value[0][0]
seen_rows.append(int64_col)
# Compare.
expected_row = _ROWS[int64_col]
self.assertEqual(int64_col, expected_row[0])
self.assertEqual(compat.as_str(string_col), ("s_%d" % int64_col) if expected_row[1]
else "s_default")
self.assertItemsEqual(seen_rows, range(num_rows))
with self.assertRaisesOpError("is closed and has insufficient elements "
"\\(requested 1, current size 0\\)"):
sess.run([key, value])
def testReadingSingleRowUsingFeatures(self):
self._ReadAndCheckRowsUsingFeatures(1)
def testReadingMultipleRowsUsingFeatures(self):
self._ReadAndCheckRowsUsingFeatures(10)
def testReadingMultipleRowsUsingColumns(self):
num_rows = 10
self.server.handler.num_rows = num_rows
with self.test_session() as sess:
reader = tf.cloud.BigQueryReader(
project_id=_PROJECT,
dataset_id=_DATASET,
table_id=_TABLE,
num_partitions=4,
columns=["int64_col", "float_col", "string_col"],
timestamp_millis=1,
test_end_point=("%s:%s" % (self.server.httpd.server_address[0],
self.server.httpd.server_address[1])))
key, value = _SetUpQueue(reader)
seen_rows = []
for row_index in range(num_rows):
returned_row_id, example_proto = sess.run([key, value])
example = example_pb2.Example()
example.ParseFromString(example_proto)
self.assertIn("int64_col", example.features.feature)
feature = example.features.feature["int64_col"]
self.assertEqual(len(feature.int64_list.value), 1)
int64_col = feature.int64_list.value[0]
seen_rows.append(int64_col)
# Create our expected Example.
expected_example = example_pb2.Example()
expected_example = _ConvertRowToExampleProto(_ROWS[int64_col])
# Compare.
self.assertProtoEquals(example, expected_example)
self.assertEqual(row_index, int(returned_row_id))
self.assertItemsEqual(seen_rows, range(num_rows))
with self.assertRaisesOpError("is closed and has insufficient elements "
"\\(requested 1, current size 0\\)"):
sess.run([key, value])
if __name__ == "__main__":
tf.test.main()