Remove AWS/S3 related code from tensorflow/contrib.

These have moved to github.com/tensorflow/io

PiperOrigin-RevId: 266454456
Parent: ed9dc39b0b
Commit: 37479ad906
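For users depending on the removed `tf.contrib.kinesis` module, the replacement lives in the tensorflow/io project. Below is a minimal migration sketch; the `tensorflow_io.kinesis` import path is an assumption based on early tensorflow-io releases, so check github.com/tensorflow/io for the current API:

```python
# Hypothetical migration sketch -- the tensorflow_io import path is an
# assumption; consult github.com/tensorflow/io for the current location
# of KinesisDataset.
import tensorflow as tf
from tensorflow_io.kinesis import KinesisDataset  # pip install tensorflow-io

tf.compat.v1.enable_eager_execution()

# Mirrors the tf.contrib.kinesis usage shown in the deleted docstring below.
dataset = KinesisDataset("kinesis_stream_name", read_indefinitely=False)
for element in dataset:
  print(element)
```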
CODEOWNERS
@@ -5,7 +5,6 @@
 /tensorflow/core/debug @caisq
 /tensorflow/core/nccl/ @azaks2 @chsigg
 /tensorflow/core/platform/windows/ @mrry
-/tensorflow/core/platform/s3 @yongtang
 /tensorflow/python/autograph/ @mdanatg @kkimdev
 /tensorflow/python/debug @caisq
 /tensorflow/python/eager @jaingaurav @alextp
@@ -38,7 +37,6 @@
 /tensorflow/contrib/hvx/ @satok16
 /tensorflow/contrib/integrate/ @shoyer
 /tensorflow/contrib/kernel_methods/ @petrosmol
-/tensorflow/contrib/kinesis @yongtang
 /tensorflow/contrib/ios_examples/ @petewarden
 /tensorflow/contrib/labeled_tensor/ @shoyer
 /tensorflow/contrib/layers/ @fchollet @martinwicke
tensorflow/contrib/BUILD
@@ -108,15 +108,6 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python/estimator:estimator_py",
     ] + select({
-        "//tensorflow:android": [],
-        "//tensorflow:ios": [],
-        "//tensorflow:linux_s390x": [],
-        "//tensorflow:windows": [],
-        "//tensorflow:no_aws_support": [],
-        "//conditions:default": [
-            "//tensorflow/contrib/kinesis",
-        ],
-    }) + select({
         "//tensorflow:android": [],
         "//tensorflow:ios": [],
         "//tensorflow:linux_s390x": [],
@@ -165,16 +156,7 @@ cc_library(
         "//tensorflow/contrib/tensor_forest:stats_ops_kernels",
         "//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
         "//tensorflow/contrib/text:all_kernels",
-    ] + select({
-        "//tensorflow:android": [],
-        "//tensorflow:ios": [],
-        "//tensorflow:linux_s390x": [],
-        "//tensorflow:windows": [],
-        "//tensorflow:no_aws_support": [],
-        "//conditions:default": [
-            "//tensorflow/contrib/kinesis:dataset_kernels",
-        ],
-    }) + if_not_windows([
+    ] + if_not_windows([
         "//tensorflow/contrib/tensorrt:trt_op_kernels",
     ]),
 )
@@ -196,16 +178,7 @@ cc_library(
         "//tensorflow/contrib/tensor_forest:stats_ops_op_lib",
         "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
         "//tensorflow/contrib/text:all_ops",
-    ] + select({
-        "//tensorflow:android": [],
-        "//tensorflow:ios": [],
-        "//tensorflow:linux_s390x": [],
-        "//tensorflow:windows": [],
-        "//tensorflow:no_aws_support": [],
-        "//conditions:default": [
-            "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
-        ],
-    }) + if_not_windows([
+    ] + if_not_windows([
         "//tensorflow/compiler/tf2tensorrt:trt_op_libs",
     ]) + select({
         "//tensorflow:android": [],
tensorflow/contrib/kinesis/BUILD (deleted)
@@ -1,113 +0,0 @@
package(default_visibility = ["//tensorflow:internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load(
    "//tensorflow:tensorflow.bzl",
    "tf_custom_op_library",
    "tf_custom_op_py_library",
    "tf_gen_op_libs",
    "tf_gen_op_wrapper_py",
    "tf_kernel_library",
    "tf_py_test",
)

py_library(
    name = "kinesis",
    srcs = ["__init__.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":dataset_ops",
    ],
)

tf_custom_op_library(
    name = "_dataset_ops.so",
    srcs = ["ops/dataset_ops.cc"],
    deps = [":dataset_kernels"],
)

tf_gen_op_libs(
    op_lib_names = ["dataset_ops"],
)

cc_library(
    name = "dataset_kernels",
    srcs = [
        "kernels/kinesis_dataset_ops.cc",
    ],
    deps = [
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core/platform/s3:aws_crypto",
        "//third_party/eigen3",
        "@aws",
        "@com_google_protobuf//:protobuf_headers",
    ],
    alwayslink = 1,
)

py_library(
    name = "dataset_ops",
    srcs = [
        "python/ops/kinesis_dataset_ops.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":kinesis_op_loader",
        "//tensorflow/python:dataset_ops_gen",
        "//tensorflow/python:util",
        "//tensorflow/python/data/ops:dataset_ops",
        "//tensorflow/python/data/util:nest",
    ],
)

tf_gen_op_wrapper_py(
    name = "gen_dataset_ops",
    out = "python/ops/gen_dataset_ops.py",
    deps = ["//tensorflow/contrib/kinesis:dataset_ops_op_lib"],
)

tf_kernel_library(
    name = "dataset_ops_kernels",
    deps = [
        ":dataset_kernels",
        "//tensorflow/core:framework",
    ],
    alwayslink = 1,
)

tf_custom_op_py_library(
    name = "kinesis_op_loader",
    srcs = ["python/ops/kinesis_op_loader.py"],
    dso = ["//tensorflow/contrib/kinesis:_dataset_ops.so"],
    kernels = [
        ":dataset_ops_kernels",
        "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":gen_dataset_ops",
        "//tensorflow/contrib/util:util_py",
        "//tensorflow/python:platform",
    ],
)

tf_py_test(
    name = "kinesis_test",
    srcs = ["python/kernel_tests/kinesis_test.py"],
    additional_deps = [
        ":kinesis",
        "//third_party/py/numpy",
        "//tensorflow/python:client_testlib",
        "//tensorflow/python:framework",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
    tags = [
        "manual",
        "no_windows",
        "notap",
    ],
)
tensorflow/contrib/kinesis/__init__.py (deleted)
@@ -1,32 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Kinesis Dataset.

@@KinesisDataset
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.kinesis.python.ops.kinesis_dataset_ops import KinesisDataset

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
    "KinesisDataset",
]

remove_undocumented(__name__)
tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc (deleted)
@@ -1,362 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <aws/core/Aws.h>
#include <aws/core/config/AWSProfileConfigLoader.h>
#include <aws/core/utils/Outcome.h>
#include <aws/kinesis/KinesisClient.h>
#include <aws/kinesis/model/DescribeStreamRequest.h>
#include <aws/kinesis/model/GetRecordsRequest.h>
#include <aws/kinesis/model/GetShardIteratorRequest.h>
#include <aws/kinesis/model/PutRecordsRequest.h>
#include <aws/kinesis/model/ShardIteratorType.h>
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/platform/s3/aws_crypto.h"

namespace tensorflow {
namespace {

Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() {
  static Aws::Client::ClientConfiguration config;
  const char* endpoint = getenv("KINESIS_ENDPOINT");
  if (endpoint) {
    config.endpointOverride = Aws::String(endpoint);
  }
  const char* region = getenv("AWS_REGION");
  if (region) {
    config.region = Aws::String(region);
  } else {
    // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
    // is set with a truthy value.
    const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
    string load_config =
        load_config_env ? absl::AsciiStrToLower(load_config_env) : "";
    if (load_config == "true" || load_config == "1") {
      Aws::String config_file;
      // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
      const char* config_file_env = getenv("AWS_CONFIG_FILE");
      if (config_file_env) {
        config_file = config_file_env;
      } else {
        const char* home_env = getenv("HOME");
        if (home_env) {
          config_file = home_env;
          config_file += "/.aws/config";
        }
      }
      Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
      // Load the configuration. If successful, get the region.
      // If the load is not successful, then generate a warning.
      if (loader.Load()) {
        auto profiles = loader.GetProfiles();
        if (!profiles["default"].GetRegion().empty()) {
          config.region = profiles["default"].GetRegion();
        }
      } else {
        LOG(WARNING) << "Failed to load the profile in " << config_file << ".";
      }
    }
  }
  const char* use_https = getenv("KINESIS_USE_HTTPS");
  if (use_https) {
    if (use_https[0] == '0') {
      config.scheme = Aws::Http::Scheme::HTTP;
    } else {
      config.scheme = Aws::Http::Scheme::HTTPS;
    }
  }
  const char* verify_ssl = getenv("KINESIS_VERIFY_SSL");
  if (verify_ssl) {
    if (verify_ssl[0] == '0') {
      config.verifySSL = false;
    } else {
      config.verifySSL = true;
    }
  }
  const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC");
  if (connect_timeout) {
    int64 timeout;

    if (strings::safe_strto64(connect_timeout, &timeout)) {
      config.connectTimeoutMs = timeout;
    }
  }
  const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC");
  if (request_timeout) {
    int64 timeout;

    if (strings::safe_strto64(request_timeout, &timeout)) {
      config.requestTimeoutMs = timeout;
    }
  }

  return &config;
}

Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
  static Aws::Client::ClientConfiguration* config =
      InitializeDefaultClientConfig();
  return *config;
}

static mutex mu(LINKER_INITIALIZED);
static unsigned count(0);
void AwsInitAPI() {
  mutex_lock lock(mu);
  count++;
  if (count == 1) {
    Aws::SDKOptions options;
    options.cryptoOptions.sha256Factory_create_fn = []() {
      return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
    };
    options.cryptoOptions.sha256HMACFactory_create_fn = []() {
      return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
    };
    Aws::InitAPI(options);
  }
}
void AwsShutdownAPI() {
  mutex_lock lock(mu);
  count--;
  if (count == 0) {
    Aws::SDKOptions options;
    Aws::ShutdownAPI(options);
  }
}
void ShutdownClient(Aws::Kinesis::KinesisClient* client) {
  if (client != nullptr) {
    delete client;
    AwsShutdownAPI();
  }
}
}  // namespace

class KinesisDatasetOp : public DatasetOpKernel {
 public:
  using DatasetOpKernel::DatasetOpKernel;

  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
    std::string stream = "";
    OP_REQUIRES_OK(ctx,
                   data::ParseScalarArgument<tstring>(ctx, "stream", &stream));
    std::string shard = "";
    OP_REQUIRES_OK(ctx,
                   data::ParseScalarArgument<tstring>(ctx, "shard", &shard));
    bool read_indefinitely = true;
    OP_REQUIRES_OK(ctx, data::ParseScalarArgument<bool>(
                            ctx, "read_indefinitely", &read_indefinitely));
    int64 interval = -1;
    OP_REQUIRES_OK(
        ctx, data::ParseScalarArgument<int64>(ctx, "interval", &interval));
    OP_REQUIRES(ctx, (interval > 0),
                errors::InvalidArgument(
                    "Interval value should be larger than 0, got ", interval));
    *output = new Dataset(ctx, stream, shard, read_indefinitely, interval);
  }

 private:
  class Dataset : public DatasetBase {
   public:
    Dataset(OpKernelContext* ctx, const string& stream, const string& shard,
            const bool read_indefinitely, const int64 interval)
        : DatasetBase(DatasetContext(ctx)),
          stream_(stream),
          shard_(shard),
          read_indefinitely_(read_indefinitely),
          interval_(interval) {}

    std::unique_ptr<IteratorBase> MakeIteratorInternal(
        const string& prefix) const override {
      return std::unique_ptr<IteratorBase>(
          new Iterator({this, strings::StrCat(prefix, "::Kinesis")}));
    }

    const DataTypeVector& output_dtypes() const override {
      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
      return *dtypes;
    }

    const std::vector<PartialTensorShape>& output_shapes() const override {
      static std::vector<PartialTensorShape>* shapes =
          new std::vector<PartialTensorShape>({{}});
      return *shapes;
    }

    string DebugString() const override { return "KinesisDatasetOp::Dataset"; }

   protected:
    Status AsGraphDefInternal(SerializationContext* ctx,
                              DatasetGraphDefBuilder* b,
                              Node** output) const override {
      Node* stream = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream));
      Node* shard = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard));
      Node* read_indefinitely = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely));
      Node* interval = nullptr;
      TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval));
      TF_RETURN_IF_ERROR(b->AddDataset(
          this, {stream, shard, read_indefinitely, interval}, output));
      return Status::OK();
    }

   private:
    class Iterator : public DatasetIterator<Dataset> {
     public:
      explicit Iterator(const Params& params)
          : DatasetIterator<Dataset>(params),
            client_(nullptr, ShutdownClient) {}

      Status GetNextInternal(IteratorContext* ctx,
                             std::vector<Tensor>* out_tensors,
                             bool* end_of_sequence) override {
        mutex_lock l(mu_);
        if (iterator_ == "") {
          TF_RETURN_IF_ERROR(SetupStreamsLocked());
        }
        do {
          Aws::Kinesis::Model::GetRecordsRequest request;
          auto outcome = client_->GetRecords(
              request.WithShardIterator(iterator_).WithLimit(1));
          if (!outcome.IsSuccess()) {
            return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
                                   outcome.GetError().GetMessage());
          }
          if (outcome.GetResult().GetRecords().size() == 0) {
            // If no records were returned then nothing is available at the
            // moment.
            if (!dataset()->read_indefinitely_) {
              *end_of_sequence = true;
              return Status::OK();
            }
            // Continue the loop after a period of time.
            ctx->env()->SleepForMicroseconds(dataset()->interval_);
            continue;
          }
          if (outcome.GetResult().GetRecords().size() != 1) {
            return errors::Unknown("invalid number of records ",
                                   outcome.GetResult().GetRecords().size(),
                                   " returned");
          }

          iterator_ = outcome.GetResult().GetNextShardIterator();

          const auto& data = outcome.GetResult().GetRecords()[0].GetData();
          StringPiece value(
              reinterpret_cast<const char*>(data.GetUnderlyingData()),
              data.GetLength());
          Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
          value_tensor.scalar<std::string>()() = std::string(value);
          out_tensors->emplace_back(std::move(value_tensor));

          *end_of_sequence = false;
          return Status::OK();
        } while (true);
      }

     protected:
      Status SaveInternal(IteratorStateWriter* writer) override {
        return errors::Unimplemented("SaveInternal is currently not supported");
      }

      Status RestoreInternal(IteratorContext* ctx,
                             IteratorStateReader* reader) override {
        return errors::Unimplemented(
            "RestoreInternal is currently not supported");
      }

     private:
      // Sets up Kinesis streams to read from.
      Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
        AwsInitAPI();
        client_.reset(
            new Aws::Kinesis::KinesisClient(GetDefaultClientConfig()));

        Aws::Kinesis::Model::DescribeStreamRequest request;
        auto outcome = client_->DescribeStream(
            request.WithStreamName(dataset()->stream_.c_str()));
        if (!outcome.IsSuccess()) {
          return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
                                 outcome.GetError().GetMessage());
        }
        Aws::String shard;
        Aws::String sequence;
        if (dataset()->shard_ == "") {
          if (outcome.GetResult().GetStreamDescription().GetShards().size() !=
              1) {
            return errors::InvalidArgument(
                "shard has to be provided unless the stream only has one "
                "shard, there are ",
                outcome.GetResult().GetStreamDescription().GetShards().size(),
                " shards in stream ", dataset()->stream_);
          }
          shard = outcome.GetResult()
                      .GetStreamDescription()
                      .GetShards()[0]
                      .GetShardId();
          sequence = outcome.GetResult()
                         .GetStreamDescription()
                         .GetShards()[0]
                         .GetSequenceNumberRange()
                         .GetStartingSequenceNumber();
        } else {
          for (const auto& entry :
               outcome.GetResult().GetStreamDescription().GetShards()) {
            if (entry.GetShardId() == dataset()->shard_.c_str()) {
              shard = entry.GetShardId();
              sequence =
                  entry.GetSequenceNumberRange().GetStartingSequenceNumber();
              break;
            }
          }
          if (shard == "") {
            return errors::InvalidArgument("no shard ", dataset()->shard_,
                                           " in stream ", dataset()->stream_);
          }
        }

        Aws::Kinesis::Model::GetShardIteratorRequest iterator_request;
        auto iterator_outcome = client_->GetShardIterator(
            iterator_request.WithStreamName(dataset()->stream_.c_str())
                .WithShardId(shard)
                .WithShardIteratorType(
                    Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER)
                .WithStartingSequenceNumber(sequence));
        if (!iterator_outcome.IsSuccess()) {
          return errors::Unknown(iterator_outcome.GetError().GetExceptionName(),
                                 ": ",
                                 iterator_outcome.GetError().GetMessage());
        }
        iterator_ = iterator_outcome.GetResult().GetShardIterator();
        return Status::OK();
      }

      mutex mu_;
      Aws::String iterator_ GUARDED_BY(mu_);
      std::unique_ptr<Aws::Kinesis::KinesisClient, decltype(&ShutdownClient)>
          client_ GUARDED_BY(mu_);
    };

    const std::string stream_;
    const std::string shard_;
    const bool read_indefinitely_;
    const int64 interval_;
  };
};

REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU),
                        KinesisDatasetOp);

}  // namespace tensorflow
tensorflow/contrib/kinesis/ops/dataset_ops.cc (deleted)
@@ -1,42 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

REGISTER_OP("KinesisDataset")
    .Input("stream: string")
    .Input("shard: string")
    .Input("read_indefinitely: bool")
    .Input("interval: int64")
    .Output("handle: variant")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ScalarShape)
    .Doc(R"doc(
Creates a dataset that emits the messages of one or more Kinesis streams.

stream: A `tf.string` tensor containing the name of the stream.
shard: A `tf.string` tensor containing the id of the shard.
read_indefinitely: If `True`, the Kinesis dataset will keep retrying
  on `EOF` after the `interval` period. If `False`, then
  the dataset will stop on `EOF`. The default value is `True`.
interval: The interval for the Kinesis Client to wait before
  it tries to get records again (in microseconds).
)doc");

}  // namespace tensorflow
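For context on how this op was reached from Python, the deleted `kinesis_dataset_ops.py` further down calls the generated wrapper directly. A short sketch of that call path follows; the argument values are illustrative. Note that the kernel above passes `interval` to `SleepForMicroseconds`, so the unit is microseconds and the default of 100000 corresponds to 100 ms:

```python
# Sketch of the raw op invocation used by the deleted Python wrapper below.
# gen_dataset_ops is the wrapper generated from this REGISTER_OP definition.
from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops

variant_handle = gen_dataset_ops.kinesis_dataset(
    stream="kinesis_stream_name",  # name of the Kinesis stream
    shard="",                      # empty string: stream must have one shard
    read_indefinitely=False,       # stop at EOF instead of polling forever
    interval=100000)               # poll interval, slept in microseconds
```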
tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py (deleted)
@@ -1,142 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
# ==============================================================================
"""Tests for KinesisDataset.

NOTE: boto3 is needed and the test has to be invoked manually:
```
$ bazel test -s --verbose_failures --config=opt \
    --action_env=AWS_ACCESS_KEY_ID=XXXXXX \
    --action_env=AWS_SECRET_ACCESS_KEY=XXXXXX \
    //tensorflow/contrib/kinesis:kinesis_test
```
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import boto3

from tensorflow.contrib.kinesis.python.ops import kinesis_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class KinesisDatasetTest(test.TestCase):

  def testKinesisDatasetOneShard(self):
    client = boto3.client('kinesis', region_name='us-east-1')

    # Set up the Kinesis stream with 1 shard.
    stream_name = "tf_kinesis_test_1"
    client.create_stream(StreamName=stream_name, ShardCount=1)
    # Wait until stream exists, default is 10 * 18 seconds.
    client.get_waiter('stream_exists').wait(StreamName=stream_name)
    for i in range(10):
      data = "D" + str(i)
      client.put_record(
          StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))

    stream = array_ops.placeholder(dtypes.string, shape=[])
    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = kinesis_dataset_ops.KinesisDataset(
        stream, read_indefinitely=False).repeat(num_epochs)
    batch_dataset = repeat_dataset.batch(batch_size)

    iterator = iterator_ops.Iterator.from_structure(
        dataset_ops.get_legacy_output_types(batch_dataset))
    init_op = iterator.make_initializer(repeat_dataset)
    init_batch_op = iterator.make_initializer(batch_dataset)
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      # Basic test: read from shard 0 of stream 1.
      sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1})
      for i in range(10):
        self.assertEqual("D" + str(i), sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

    client.delete_stream(StreamName=stream_name)
    # Wait until stream deleted, default is 10 * 18 seconds.
    client.get_waiter('stream_not_exists').wait(StreamName=stream_name)

  def testKinesisDatasetTwoShards(self):
    client = boto3.client('kinesis', region_name='us-east-1')

    # Set up the Kinesis stream with 2 shards.
    stream_name = "tf_kinesis_test_2"
    client.create_stream(StreamName=stream_name, ShardCount=2)
    # Wait until stream exists, default is 10 * 18 seconds.
    client.get_waiter('stream_exists').wait(StreamName=stream_name)

    for i in range(10):
      data = "D" + str(i)
      client.put_record(
          StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))
    response = client.describe_stream(StreamName=stream_name)
    shard_id_0 = response["StreamDescription"]["Shards"][0]["ShardId"]
    shard_id_1 = response["StreamDescription"]["Shards"][1]["ShardId"]

    stream = array_ops.placeholder(dtypes.string, shape=[])
    shard = array_ops.placeholder(dtypes.string, shape=[])
    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
    batch_size = array_ops.placeholder(dtypes.int64, shape=[])

    repeat_dataset = kinesis_dataset_ops.KinesisDataset(
        stream, shard, read_indefinitely=False).repeat(num_epochs)
    batch_dataset = repeat_dataset.batch(batch_size)

    iterator = iterator_ops.Iterator.from_structure(
        dataset_ops.get_legacy_output_types(batch_dataset))
    init_op = iterator.make_initializer(repeat_dataset)
    init_batch_op = iterator.make_initializer(batch_dataset)
    get_next = iterator.get_next()

    data = []
    with self.cached_session() as sess:
      # Basic test: read from shard 0 of stream 2.
      sess.run(
          init_op, feed_dict={
              stream: stream_name, shard: shard_id_0, num_epochs: 1})
      with self.assertRaises(errors.OutOfRangeError):
        # Use range(11) to guarantee the OutOfRangeError.
        for i in range(11):
          data.append(sess.run(get_next))

      # Basic test: read from shard 1 of stream 2.
      sess.run(
          init_op, feed_dict={
              stream: stream_name, shard: shard_id_1, num_epochs: 1})
      with self.assertRaises(errors.OutOfRangeError):
        # Use range(11) to guarantee the OutOfRangeError.
        for i in range(11):
          data.append(sess.run(get_next))

    data.sort()
    self.assertEqual(data, ["D" + str(i) for i in range(10)])

    client.delete_stream(StreamName=stream_name)
    # Wait until stream deleted, default is 10 * 18 seconds.
    client.get_waiter('stream_not_exists').wait(StreamName=stream_name)


if __name__ == "__main__":
  test.main()
tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py (deleted)
@@ -1,90 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Kinesis Dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops
from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader  # pylint: disable=unused-import
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.util import deprecation


class KinesisDataset(dataset_ops.DatasetSource):
  """A Kinesis Dataset that consumes messages from a stream.

  Kinesis is a managed service provided by AWS for data streaming.
  This dataset reads messages from Kinesis with each message presented
  as a `tf.string`.

  For example, we can construct and use the KinesisDataset as follows:
  ```python
  tf.compat.v1.enable_eager_execution()

  dataset = tf.contrib.kinesis.KinesisDataset(
      "kinesis_stream_name", read_indefinitely=False)
  for element in dataset:
    print(element)
  ```

  Since Kinesis is a data streaming service, data may not be available
  at the time it is being read. The argument `read_indefinitely` is
  used to control the behavior in this situation. If `read_indefinitely`
  is `True`, then `KinesisDataset` will keep retrying to retrieve data
  from the stream. If `read_indefinitely` is `False`, an `OutOfRangeError`
  is returned immediately instead.
  """

  @deprecation.deprecated(
      None,
      "tf.contrib.kinesis will be removed in 2.0, the support for Kinesis "
      "will continue to be provided through the tensorflow/io GitHub project.")
  def __init__(self,
               stream,
               shard="",
               read_indefinitely=True,
               interval=100000):
    """Create a KinesisDataset.

    Args:
      stream: A `tf.string` tensor containing the name of the stream.
      shard: A `tf.string` tensor containing the id of the shard.
      read_indefinitely: If `True`, the Kinesis dataset will keep retrying
        on `EOF` after the `interval` period. If `False`, then
        the dataset will stop on `EOF`. The default value is `True`.
      interval: The interval for the Kinesis Client to wait before
        it tries to get records again (in microseconds).
    """
    self._stream = ops.convert_to_tensor(
        stream, dtype=dtypes.string, name="stream")
    self._shard = ops.convert_to_tensor(
        shard, dtype=dtypes.string, name="shard")
    self._read_indefinitely = ops.convert_to_tensor(
        read_indefinitely, dtype=dtypes.bool, name="read_indefinitely")
    self._interval = ops.convert_to_tensor(
        interval, dtype=dtypes.int64, name="interval")
    super(KinesisDataset, self).__init__(self._as_variant_tensor())

  def _as_variant_tensor(self):
    return gen_dataset_ops.kinesis_dataset(
        self._stream, self._shard, self._read_indefinitely, self._interval)

  @property
  def element_spec(self):
    return tensor_spec.TensorSpec([], dtypes.string)
tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py (deleted)
@@ -1,24 +0,0 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python helper for loading kinesis ops and kernels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.util import loader
from tensorflow.python.platform import resource_loader

_dataset_ops = loader.load_op_library(
    resource_loader.get_path_to_datafile("../../_dataset_ops.so"))