Migrate tensorflow/contrib/bigtable to use cloud/bigtable/table.h instead of cloud/bigtable/internal/table.h
PiperOrigin-RevId: 245271315
This commit is contained in:
parent
113c1ee358
commit
6b8c6cb57f
@ -14,7 +14,6 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
|
||||
@ -262,16 +261,16 @@ class ToBigtableOp : public AsyncOpKernel {
|
||||
}
|
||||
components.clear();
|
||||
}
|
||||
grpc::Status mutation_status;
|
||||
::google::cloud::Status mutation_status;
|
||||
std::vector<::google::cloud::bigtable::FailedMutation> failures =
|
||||
resource->table().BulkApply(mutation, mutation_status);
|
||||
if (!mutation_status.ok()) {
|
||||
LOG(ERROR) << "Failure applying mutation: "
|
||||
<< mutation_status.error_code() << " - "
|
||||
<< mutation_status.error_message() << " ("
|
||||
<< mutation_status.error_details() << ").";
|
||||
}
|
||||
resource->table().BulkApply(mutation);
|
||||
if (!failures.empty()) {
|
||||
mutation_status = failures.front().status();
|
||||
if (!mutation_status.ok()) {
|
||||
LOG(ERROR) << "Failure applying mutation: "
|
||||
<< mutation_status.code() << " - "
|
||||
<< mutation_status.message() << ".";
|
||||
}
|
||||
::google::bigtable::v2::MutateRowsRequest request;
|
||||
mutation.MoveTo(&request);
|
||||
for (const auto& failure : failures) {
|
||||
@ -282,12 +281,11 @@ class ToBigtableOp : public AsyncOpKernel {
|
||||
}
|
||||
}
|
||||
OP_REQUIRES_ASYNC(
|
||||
ctx, failures.empty() && mutation_status.ok(),
|
||||
ctx, failures.empty(),
|
||||
errors::Unknown("Failure while writing to Cloud Bigtable: ",
|
||||
mutation_status.error_code(), " - ",
|
||||
mutation_status.error_message(), " (",
|
||||
mutation_status.error_details(),
|
||||
"), # of mutation failures: ", failures.size(),
|
||||
mutation_status.code(), " - ",
|
||||
mutation_status.message(),
|
||||
"; # of mutation failures: ", failures.size(),
|
||||
". See the log for the specific error details."),
|
||||
done);
|
||||
} while (!end_of_sequence);
|
||||
|
@ -16,22 +16,6 @@ limitations under the License.
|
||||
#include "tensorflow/contrib/bigtable/kernels/bigtable_lib.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status GrpcStatusToTfStatus(const ::grpc::Status& status) {
|
||||
if (status.ok()) {
|
||||
return Status::OK();
|
||||
}
|
||||
auto grpc_code = status.error_code();
|
||||
if (status.error_code() == ::grpc::StatusCode::ABORTED ||
|
||||
status.error_code() == ::grpc::StatusCode::UNAVAILABLE ||
|
||||
status.error_code() == ::grpc::StatusCode::OUT_OF_RANGE) {
|
||||
grpc_code = ::grpc::StatusCode::INTERNAL;
|
||||
}
|
||||
return Status(static_cast<::tensorflow::error::Code>(grpc_code),
|
||||
strings::StrCat("Error reading from Cloud Bigtable: ",
|
||||
status.error_message()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
::tensorflow::error::Code GcpErrorCodeToTfErrorCode(
|
||||
::google::cloud::StatusCode code) {
|
||||
|
@ -16,16 +16,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
|
||||
#define TENSORFLOW_CONTRIB_BIGTABLE_KERNELS_BIGTABLE_LIB_H_
|
||||
|
||||
// Note: we use bigtable/client/internal/table.h as this is the no-exception API
|
||||
|
||||
#include "google/cloud/bigtable/data_client.h"
|
||||
#include "google/cloud/bigtable/internal/table.h"
|
||||
#include "google/cloud/bigtable/table.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status GrpcStatusToTfStatus(const ::grpc::Status& status);
|
||||
Status GcpStatusToTfStatus(const ::google::cloud::Status& status);
|
||||
|
||||
string RegexFromStringSet(const std::vector<string>& strs);
|
||||
@ -66,7 +63,7 @@ class BigtableTableResource : public ResourceBase {
|
||||
|
||||
~BigtableTableResource() override { client_->Unref(); }
|
||||
|
||||
::google::cloud::bigtable::noex::Table& table() { return table_; }
|
||||
::google::cloud::bigtable::Table& table() { return table_; }
|
||||
|
||||
string DebugString() const override {
|
||||
return strings::StrCat(
|
||||
@ -77,7 +74,7 @@ class BigtableTableResource : public ResourceBase {
|
||||
private:
|
||||
BigtableClientResource* client_; // Ownes one ref.
|
||||
const string table_name_;
|
||||
::google::cloud::bigtable::noex::Table table_;
|
||||
::google::cloud::bigtable::Table table_;
|
||||
};
|
||||
|
||||
namespace data {
|
||||
|
@ -152,18 +152,19 @@ class BigtableLookupDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
if (input_tensors[0].NumElements() == 1) {
|
||||
// Single key lookup.
|
||||
::google::cloud::Status status;
|
||||
auto pair = dataset()->table_->table().ReadRow(
|
||||
input_tensors[0].scalar<string>()(), dataset()->filter_, status);
|
||||
if (!status.ok()) {
|
||||
return GcpStatusToTfStatus(status);
|
||||
::google::cloud::StatusOr<
|
||||
std::pair<bool, ::google::cloud::bigtable::Row>>
|
||||
row = dataset()->table_->table().ReadRow(
|
||||
input_tensors[0].scalar<string>()(), dataset()->filter_);
|
||||
if (!row.ok()) {
|
||||
return GcpStatusToTfStatus(row.status());
|
||||
}
|
||||
if (!pair.first) {
|
||||
if (!row->first) {
|
||||
return errors::DataLoss("Row key '",
|
||||
input_tensors[0].scalar<string>()(),
|
||||
"' not found.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ParseRow(ctx, pair.second, out_tensors));
|
||||
TF_RETURN_IF_ERROR(ParseRow(ctx, row->second, out_tensors));
|
||||
} else {
|
||||
// Batched get.
|
||||
return errors::Unimplemented(
|
||||
|
@ -125,15 +125,15 @@ class BigtableSampleKeyPairsDatasetOp : public DatasetOpKernel {
|
||||
// ensure we don't accidentally miss any subsets of the requested range by
|
||||
// including `begin_key()` and `end_key()` as appropriate.
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
grpc::Status status;
|
||||
std::vector<google::cloud::bigtable::RowKeySample> row_keys =
|
||||
dataset()->table().table().SampleRows(status);
|
||||
if (!status.ok()) {
|
||||
return GrpcStatusToTfStatus(status);
|
||||
::google::cloud::StatusOr<
|
||||
std::vector<::google::cloud::bigtable::RowKeySample>>
|
||||
row_key_samples = dataset()->table().table().SampleRows();
|
||||
if (!row_key_samples.ok()) {
|
||||
return GcpStatusToTfStatus(row_key_samples.status());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < row_keys.size(); ++i) {
|
||||
string row_key(row_keys[i].row_key);
|
||||
for (const auto& row_key_sample : *row_key_samples) {
|
||||
string row_key(row_key_sample.row_key);
|
||||
if (dataset()->key_range_.contains_key(row_key)) {
|
||||
// First key: check to see if we need to add the begin_key.
|
||||
if (keys_.empty() && dataset()->key_range_.begin_key() != row_key) {
|
||||
|
@ -80,12 +80,14 @@ class BigtableSampleKeysDatasetOp : public DatasetOpKernel {
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
::grpc::Status status;
|
||||
row_keys_ = dataset()->table()->table().SampleRows(status);
|
||||
if (!status.ok()) {
|
||||
::google::cloud::StatusOr<
|
||||
std::vector<::google::cloud::bigtable::RowKeySample>>
|
||||
sampled_rows = dataset()->table()->table().SampleRows();
|
||||
if (!sampled_rows.ok()) {
|
||||
row_keys_.clear();
|
||||
return GrpcStatusToTfStatus(status);
|
||||
return GcpStatusToTfStatus(sampled_rows.status());
|
||||
}
|
||||
row_keys_ = std::move(*sampled_rows);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user