Remove AWS/S3 related code from tensorflow/contrib.

These have moved to github.com/tensorflow/io.

PiperOrigin-RevId: 266454456

commit 37479ad906
parent ed9dc39b0b
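Note on migration: the `tf.contrib.kinesis` entry point deleted below was already marked deprecated in favor of the tensorflow/io project (see the `@deprecation.deprecated` annotation in the removed `kinesis_dataset_ops.py`). A minimal before/after sketch; the `tensorflow_io` module path and the `KinesisDataset` name on the tensorflow/io side are assumptions, so check that repository for the current API:

```python
import tensorflow as tf

# Before this commit (tf.contrib, removed here):
dataset = tf.contrib.kinesis.KinesisDataset(
    "my_stream",              # placeholder stream name
    read_indefinitely=False)  # stop at EOF instead of polling forever

# After this commit (assumed tensorflow/io equivalent; `pip install tensorflow-io`):
# import tensorflow_io.kinesis as kinesis  # assumed module path
# dataset = kinesis.KinesisDataset("my_stream", read_indefinitely=False)
```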
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -5,7 +5,6 @@
 /tenosrflow/core/debug @caisq
 /tensorflow/core/nccl/ @azaks2 @chsigg
 /tensorflow/core/platform/windows/ @mrry
-/tensorflow/core/platform/s3 @yongtang
 /tensorflow/python/autograph/ @mdanatg @kkimdev
 /tensorflow/python/debug @caisq
 /tensorflow/python/eager @jaingaurav @alextp
@@ -38,7 +37,6 @@
 /tensorflow/contrib/hvx/ @satok16
 /tensorflow/contrib/integrate/ @shoyer
 /tensorflow/contrib/kernel_methods/ @petrosmol
-/tensorflow/contrib/kinesis @yongtang
 /tensorflow/contrib/ios_examples/ @petewarden
 /tensorflow/contrib/labeled_tensor/ @shoyer
 /tensorflow/contrib/layers/ @fchollet @martinwicke
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -108,15 +108,6 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python/estimator:estimator_py",
     ] + select({
-        "//tensorflow:android": [],
-        "//tensorflow:ios": [],
-        "//tensorflow:linux_s390x": [],
-        "//tensorflow:windows": [],
-        "//tensorflow:no_aws_support": [],
-        "//conditions:default": [
-            "//tensorflow/contrib/kinesis",
-        ],
-    }) + select({
         "//tensorflow:android": [],
         "//tensorflow:ios": [],
         "//tensorflow:linux_s390x": [],
@@ -165,16 +156,7 @@ cc_library(
         "//tensorflow/contrib/tensor_forest:stats_ops_kernels",
         "//tensorflow/contrib/tensor_forest:tensor_forest_kernels",
         "//tensorflow/contrib/text:all_kernels",
-    ] + select({
-        "//tensorflow:android": [],
-        "//tensorflow:ios": [],
-        "//tensorflow:linux_s390x": [],
-        "//tensorflow:windows": [],
-        "//tensorflow:no_aws_support": [],
-        "//conditions:default": [
-            "//tensorflow/contrib/kinesis:dataset_kernels",
-        ],
-    }) + if_not_windows([
+    ] + if_not_windows([
         "//tensorflow/contrib/tensorrt:trt_op_kernels",
     ]),
 )
@@ -196,16 +178,7 @@ cc_library(
         "//tensorflow/contrib/tensor_forest:stats_ops_op_lib",
         "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib",
         "//tensorflow/contrib/text:all_ops",
-    ] + select({
-        "//tensorflow:android": [],
-        "//tensorflow:ios": [],
-        "//tensorflow:linux_s390x": [],
-        "//tensorflow:windows": [],
-        "//tensorflow:no_aws_support": [],
-        "//conditions:default": [
-            "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
-        ],
-    }) + if_not_windows([
+    ] + if_not_windows([
         "//tensorflow/compiler/tf2tensorrt:trt_op_libs",
     ]) + select({
         "//tensorflow:android": [],
--- a/tensorflow/contrib/kinesis/BUILD
+++ /dev/null
@@ -1,113 +0,0 @@
-package(default_visibility = ["//tensorflow:internal"])
-
-licenses(["notice"])  # Apache 2.0
-
-exports_files(["LICENSE"])
-
-load(
-    "//tensorflow:tensorflow.bzl",
-    "tf_custom_op_library",
-    "tf_custom_op_py_library",
-    "tf_gen_op_libs",
-    "tf_gen_op_wrapper_py",
-    "tf_kernel_library",
-    "tf_py_test",
-)
-
-py_library(
-    name = "kinesis",
-    srcs = ["__init__.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":dataset_ops",
-    ],
-)
-
-tf_custom_op_library(
-    name = "_dataset_ops.so",
-    srcs = ["ops/dataset_ops.cc"],
-    deps = [":dataset_kernels"],
-)
-
-tf_gen_op_libs(
-    op_lib_names = ["dataset_ops"],
-)
-
-cc_library(
-    name = "dataset_kernels",
-    srcs = [
-        "kernels/kinesis_dataset_ops.cc",
-    ],
-    deps = [
-        "//tensorflow/core:framework_headers_lib",
-        "//tensorflow/core/platform/s3:aws_crypto",
-        "//third_party/eigen3",
-        "@aws",
-        "@com_google_protobuf//:protobuf_headers",
-    ],
-    alwayslink = 1,
-)
-
-py_library(
-    name = "dataset_ops",
-    srcs = [
-        "python/ops/kinesis_dataset_ops.py",
-    ],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":kinesis_op_loader",
-        "//tensorflow/python:dataset_ops_gen",
-        "//tensorflow/python:util",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/util:nest",
-    ],
-)
-
-tf_gen_op_wrapper_py(
-    name = "gen_dataset_ops",
-    out = "python/ops/gen_dataset_ops.py",
-    deps = ["//tensorflow/contrib/kinesis:dataset_ops_op_lib"],
-)
-
-tf_kernel_library(
-    name = "dataset_ops_kernels",
-    deps = [
-        ":dataset_kernels",
-        "//tensorflow/core:framework",
-    ],
-    alwayslink = 1,
-)
-
-tf_custom_op_py_library(
-    name = "kinesis_op_loader",
-    srcs = ["python/ops/kinesis_op_loader.py"],
-    dso = ["//tensorflow/contrib/kinesis:_dataset_ops.so"],
-    kernels = [
-        ":dataset_ops_kernels",
-        "//tensorflow/contrib/kinesis:dataset_ops_op_lib",
-    ],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":gen_dataset_ops",
-        "//tensorflow/contrib/util:util_py",
-        "//tensorflow/python:platform",
-    ],
-)
-
-tf_py_test(
-    name = "kinesis_test",
-    srcs = ["python/kernel_tests/kinesis_test.py"],
-    additional_deps = [
-        ":kinesis",
-        "//third_party/py/numpy",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python:platform_test",
-    ],
-    tags = [
-        "manual",
-        "no_windows",
-        "notap",
-    ],
-)
--- a/tensorflow/contrib/kinesis/__init__.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Kinesis Dataset.
-
-@@KinesisDataset
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kinesis.python.ops.kinesis_dataset_ops import KinesisDataset
-
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
-    "KinesisDataset",
-]
-
-remove_undocumented(__name__)
--- a/tensorflow/contrib/kinesis/kernels/kinesis_dataset_ops.cc
+++ /dev/null
@@ -1,362 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <aws/core/Aws.h>
-#include <aws/core/config/AWSProfileConfigLoader.h>
-#include <aws/core/utils/Outcome.h>
-#include <aws/kinesis/KinesisClient.h>
-#include <aws/kinesis/model/DescribeStreamRequest.h>
-#include <aws/kinesis/model/GetRecordsRequest.h>
-#include <aws/kinesis/model/GetShardIteratorRequest.h>
-#include <aws/kinesis/model/PutRecordsRequest.h>
-#include <aws/kinesis/model/ShardIteratorType.h>
-#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/platform/s3/aws_crypto.h"
-
-namespace tensorflow {
-namespace {
-
-Aws::Client::ClientConfiguration* InitializeDefaultClientConfig() {
-  static Aws::Client::ClientConfiguration config;
-  const char* endpoint = getenv("KINESIS_ENDPOINT");
-  if (endpoint) {
-    config.endpointOverride = Aws::String(endpoint);
-  }
-  const char* region = getenv("AWS_REGION");
-  if (region) {
-    config.region = Aws::String(region);
-  } else {
-    // Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
-    // is set with a truthy value.
-    const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
-    string load_config =
-        load_config_env ? absl::AsciiStrToLower(load_config_env) : "";
-    if (load_config == "true" || load_config == "1") {
-      Aws::String config_file;
-      // If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
-      const char* config_file_env = getenv("AWS_CONFIG_FILE");
-      if (config_file_env) {
-        config_file = config_file_env;
-      } else {
-        const char* home_env = getenv("HOME");
-        if (home_env) {
-          config_file = home_env;
-          config_file += "/.aws/config";
-        }
-      }
-      Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
-      // Load the configuration. If successful, get the region.
-      // If the load is not successful, then generate a warning.
-      if (loader.Load()) {
-        auto profiles = loader.GetProfiles();
-        if (!profiles["default"].GetRegion().empty()) {
-          config.region = profiles["default"].GetRegion();
-        }
-      } else {
-        LOG(WARNING) << "Failed to load the profile in " << config_file << ".";
-      }
-    }
-  }
-  const char* use_https = getenv("KINESIS_USE_HTTPS");
-  if (use_https) {
-    if (use_https[0] == '0') {
-      config.scheme = Aws::Http::Scheme::HTTP;
-    } else {
-      config.scheme = Aws::Http::Scheme::HTTPS;
-    }
-  }
-  const char* verify_ssl = getenv("KINESIS_VERIFY_SSL");
-  if (verify_ssl) {
-    if (verify_ssl[0] == '0') {
-      config.verifySSL = false;
-    } else {
-      config.verifySSL = true;
-    }
-  }
-  const char* connect_timeout = getenv("KINESIS_CONNECT_TIMEOUT_MSEC");
-  if (connect_timeout) {
-    int64 timeout;
-
-    if (strings::safe_strto64(connect_timeout, &timeout)) {
-      config.connectTimeoutMs = timeout;
-    }
-  }
-  const char* request_timeout = getenv("KINESIS_REQUEST_TIMEOUT_MSEC");
-  if (request_timeout) {
-    int64 timeout;
-
-    if (strings::safe_strto64(request_timeout, &timeout)) {
-      config.requestTimeoutMs = timeout;
-    }
-  }
-
-  return &config;
-}
-
-Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
-  static Aws::Client::ClientConfiguration* config =
-      InitializeDefaultClientConfig();
-  return *config;
-}
-
-static mutex mu(LINKER_INITIALIZED);
-static unsigned count(0);
-void AwsInitAPI() {
-  mutex_lock lock(mu);
-  count++;
-  if (count == 1) {
-    Aws::SDKOptions options;
-    options.cryptoOptions.sha256Factory_create_fn = []() {
-      return Aws::MakeShared<AWSSHA256Factory>(AWSCryptoAllocationTag);
-    };
-    options.cryptoOptions.sha256HMACFactory_create_fn = []() {
-      return Aws::MakeShared<AWSSHA256HmacFactory>(AWSCryptoAllocationTag);
-    };
-    Aws::InitAPI(options);
-  }
-}
-void AwsShutdownAPI() {
-  mutex_lock lock(mu);
-  count--;
-  if (count == 0) {
-    Aws::SDKOptions options;
-    Aws::ShutdownAPI(options);
-  }
-}
-void ShutdownClient(Aws::Kinesis::KinesisClient* client) {
-  if (client != nullptr) {
-    delete client;
-    AwsShutdownAPI();
-  }
-}
-}
-class KinesisDatasetOp : public DatasetOpKernel {
- public:
-  using DatasetOpKernel::DatasetOpKernel;
-
-  void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
-    std::string stream = "";
-    OP_REQUIRES_OK(ctx,
-                   data::ParseScalarArgument<tstring>(ctx, "stream", &stream));
-    std::string shard = "";
-    OP_REQUIRES_OK(ctx,
-                   data::ParseScalarArgument<tstring>(ctx, "shard", &shard));
-    bool read_indefinitely = true;
-    OP_REQUIRES_OK(ctx, data::ParseScalarArgument<bool>(
-                            ctx, "read_indefinitely", &read_indefinitely));
-    int64 interval = -1;
-    OP_REQUIRES_OK(
-        ctx, data::ParseScalarArgument<int64>(ctx, "interval", &interval));
-    OP_REQUIRES(ctx, (interval > 0),
-                errors::InvalidArgument(
-                    "Interval value should be large than 0, got ", interval));
-    *output = new Dataset(ctx, stream, shard, read_indefinitely, interval);
-  }
-
- private:
-  class Dataset : public DatasetBase {
-   public:
-    Dataset(OpKernelContext* ctx, const string& stream, const string& shard,
-            const bool read_indefinitely, const int64 interval)
-        : DatasetBase(DatasetContext(ctx)),
-          stream_(stream),
-          shard_(shard),
-          read_indefinitely_(read_indefinitely),
-          interval_(interval) {}
-
-    std::unique_ptr<IteratorBase> MakeIteratorInternal(
-        const string& prefix) const override {
-      return std::unique_ptr<IteratorBase>(
-          new Iterator({this, strings::StrCat(prefix, "::Kinesis")}));
-    }
-
-    const DataTypeVector& output_dtypes() const override {
-      static DataTypeVector* dtypes = new DataTypeVector({DT_STRING});
-      return *dtypes;
-    }
-
-    const std::vector<PartialTensorShape>& output_shapes() const override {
-      static std::vector<PartialTensorShape>* shapes =
-          new std::vector<PartialTensorShape>({{}});
-      return *shapes;
-    }
-
-    string DebugString() const override { return "KinesisDatasetOp::Dataset"; }
-
-   protected:
-    Status AsGraphDefInternal(SerializationContext* ctx,
-                              DatasetGraphDefBuilder* b,
-                              Node** output) const override {
-      Node* stream = nullptr;
-      TF_RETURN_IF_ERROR(b->AddScalar(stream_, &stream));
-      Node* shard = nullptr;
-      TF_RETURN_IF_ERROR(b->AddScalar(shard_, &shard));
-      Node* read_indefinitely = nullptr;
-      TF_RETURN_IF_ERROR(b->AddScalar(read_indefinitely_, &read_indefinitely));
-      Node* interval = nullptr;
-      TF_RETURN_IF_ERROR(b->AddScalar(interval_, &interval));
-      TF_RETURN_IF_ERROR(b->AddDataset(
-          this, {stream, shard, read_indefinitely, interval}, output));
-      return Status::OK();
-    }
-
-   private:
-    class Iterator : public DatasetIterator<Dataset> {
-     public:
-      explicit Iterator(const Params& params)
-          : DatasetIterator<Dataset>(params),
-            client_(nullptr, ShutdownClient) {}
-
-      Status GetNextInternal(IteratorContext* ctx,
-                             std::vector<Tensor>* out_tensors,
-                             bool* end_of_sequence) override {
-        mutex_lock l(mu_);
-        if (iterator_ == "") {
-          TF_RETURN_IF_ERROR(SetupStreamsLocked());
-        }
-        do {
-          Aws::Kinesis::Model::GetRecordsRequest request;
-          auto outcome = client_->GetRecords(
-              request.WithShardIterator(iterator_).WithLimit(1));
-          if (!outcome.IsSuccess()) {
-            return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
-                                   outcome.GetError().GetMessage());
-          }
-          if (outcome.GetResult().GetRecords().size() == 0) {
-            // If no records were returned then nothing is available at the
-            // moment.
-            if (!dataset()->read_indefinitely_) {
-              *end_of_sequence = true;
-              return Status::OK();
-            }
-            // Continue the loop after a period of time.
-            ctx->env()->SleepForMicroseconds(dataset()->interval_);
-            continue;
-          }
-          if (outcome.GetResult().GetRecords().size() != 1) {
-            return errors::Unknown("invalid number of records ",
-                                   outcome.GetResult().GetRecords().size(),
-                                   " returned");
-          }
-
-          iterator_ = outcome.GetResult().GetNextShardIterator();
-
-          const auto& data = outcome.GetResult().GetRecords()[0].GetData();
-          StringPiece value(
-              reinterpret_cast<const char*>(data.GetUnderlyingData()),
-              data.GetLength());
-          Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
-          value_tensor.scalar<std::string>()() = std::string(value);
-          out_tensors->emplace_back(std::move(value_tensor));
-
-          *end_of_sequence = false;
-          return Status::OK();
-        } while (true);
-      }
-
-     protected:
-      Status SaveInternal(IteratorStateWriter* writer) override {
-        return errors::Unimplemented("SaveInternal is currently not supported");
-      }
-
-      Status RestoreInternal(IteratorContext* ctx,
-                             IteratorStateReader* reader) override {
-        return errors::Unimplemented(
-            "RestoreInternal is currently not supported");
-      }
-
-     private:
-      // Sets up Kinesis streams to read from.
-      Status SetupStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-        AwsInitAPI();
-        client_.reset(
-            new Aws::Kinesis::KinesisClient(GetDefaultClientConfig()));
-
-        Aws::Kinesis::Model::DescribeStreamRequest request;
-        auto outcome = client_->DescribeStream(
-            request.WithStreamName(dataset()->stream_.c_str()));
-        if (!outcome.IsSuccess()) {
-          return errors::Unknown(outcome.GetError().GetExceptionName(), ": ",
-                                 outcome.GetError().GetMessage());
-        }
-        Aws::String shard;
-        Aws::String sequence;
-        if (dataset()->shard_ == "") {
-          if (outcome.GetResult().GetStreamDescription().GetShards().size() !=
-              1) {
-            return errors::InvalidArgument(
-                "shard has to be provided unless the stream only have one "
-                "shard, there are ",
-                outcome.GetResult().GetStreamDescription().GetShards().size(),
-                " shards in stream ", dataset()->stream_);
-          }
-          shard = outcome.GetResult()
-                      .GetStreamDescription()
-                      .GetShards()[0]
-                      .GetShardId();
-          sequence = outcome.GetResult()
-                         .GetStreamDescription()
-                         .GetShards()[0]
-                         .GetSequenceNumberRange()
-                         .GetStartingSequenceNumber();
-        } else {
-          for (const auto& entry :
-               outcome.GetResult().GetStreamDescription().GetShards()) {
-            if (entry.GetShardId() == dataset()->shard_.c_str()) {
-              shard = entry.GetShardId();
-              sequence =
-                  entry.GetSequenceNumberRange().GetStartingSequenceNumber();
-              break;
-            }
-          }
-          if (shard == "") {
-            return errors::InvalidArgument("no shard ", dataset()->shard_,
-                                           " in stream ", dataset()->stream_);
-          }
-        }
-
-        Aws::Kinesis::Model::GetShardIteratorRequest iterator_request;
-        auto iterator_outcome = client_->GetShardIterator(
-            iterator_request.WithStreamName(dataset()->stream_.c_str())
-                .WithShardId(shard)
-                .WithShardIteratorType(
-                    Aws::Kinesis::Model::ShardIteratorType::AT_SEQUENCE_NUMBER)
-                .WithStartingSequenceNumber(sequence));
-        if (!iterator_outcome.IsSuccess()) {
-          return errors::Unknown(iterator_outcome.GetError().GetExceptionName(),
-                                 ": ",
-                                 iterator_outcome.GetError().GetMessage());
-        }
-        iterator_ = iterator_outcome.GetResult().GetShardIterator();
-        return Status::OK();
-      }
-
-      mutex mu_;
-      Aws::String iterator_ GUARDED_BY(mu_);
-      std::unique_ptr<Aws::Kinesis::KinesisClient, decltype(&ShutdownClient)>
-          client_ GUARDED_BY(mu_);
-    };
-
-    const std::string stream_;
-    const std::string shard_;
-    const bool read_indefinitely_;
-    const int64 interval_;
-  };
-};
-
-REGISTER_KERNEL_BUILDER(Name("KinesisDataset").Device(DEVICE_CPU),
-                        KinesisDatasetOp);
-
-}  // namespace tensorflow
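The deleted kernel above configured its AWS client entirely through environment variables (`KINESIS_ENDPOINT`, `AWS_REGION`, `KINESIS_USE_HTTPS`, `KINESIS_VERIFY_SSL`, `KINESIS_CONNECT_TIMEOUT_MSEC`, `KINESIS_REQUEST_TIMEOUT_MSEC`), as read in `InitializeDefaultClientConfig()`. A sketch of pointing it at a local Kinesis-compatible endpoint; the endpoint address is a placeholder and a local emulator such as kinesalite is assumed:

```python
import os

# Environment variables read by the deleted kernel; values are placeholders
# for a local emulator, not real AWS settings.
os.environ["KINESIS_ENDPOINT"] = "localhost:4567"  # hypothetical local endpoint
os.environ["KINESIS_USE_HTTPS"] = "0"    # a leading '0' selects plain HTTP
os.environ["KINESIS_VERIFY_SSL"] = "0"   # a leading '0' disables TLS verification
os.environ["AWS_REGION"] = "us-east-1"

import tensorflow as tf

# read_indefinitely=False makes the iterator end (OutOfRangeError) once the
# shard has no more records, instead of sleeping `interval` and retrying.
dataset = tf.contrib.kinesis.KinesisDataset(
    "tf_kinesis_test_1", read_indefinitely=False)
```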
--- a/tensorflow/contrib/kinesis/ops/dataset_ops.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/shape_inference.h"
-
-namespace tensorflow {
-
-REGISTER_OP("KinesisDataset")
-    .Input("stream: string")
-    .Input("shard: string")
-    .Input("read_indefinitely: bool")
-    .Input("interval: int64")
-    .Output("handle: variant")
-    .SetIsStateful()
-    .SetShapeFn(shape_inference::ScalarShape)
-    .Doc(R"doc(
-Creates a dataset that emits the messages of one or more Kinesis topics.
-
-stream: A `tf.string` tensor containing the name of the stream.
-shard: A `tf.string` tensor containing the id of the shard.
-read_indefinitely: If `True`, the Kinesis dataset will keep retry
-  again on `EOF` after the `interval` period. If `False`, then
-  the dataset will stop on `EOF`. The default value is `True`.
-interval: The interval for the Kinesis Client to wait before
-  it tries to get records again (in millisecond).
-)doc");
-
-}  // namespace tensorflow
--- a/tensorflow/contrib/kinesis/python/kernel_tests/kinesis_test.py
+++ /dev/null
@@ -1,142 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may not
-# use this file except in compliance with the License. You may obtain a copy of
-# the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations under
-# the License.
-# ==============================================================================
-"""Tests for KinesisDataset.
-
-NOTE: boto3 is needed and the test has to be invoked manually:
-
-```
-$ bazel test -s --verbose_failures --config=opt \
-  --action_env=AWS_ACCESS_KEY_ID=XXXXXX \
-  --action_env=AWS_SECRET_ACCESS_KEY=XXXXXX \
-  //tensorflow/contrib/kinesis:kinesis_test
-```
-"""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import boto3
-
-from tensorflow.contrib.kinesis.python.ops import kinesis_dataset_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class KinesisDatasetTest(test.TestCase):
-
-  def testKinesisDatasetOneShard(self):
-    client = boto3.client('kinesis', region_name='us-east-1')
-
-    # Setup the Kinesis with 1 shard.
-    stream_name = "tf_kinesis_test_1"
-    client.create_stream(StreamName=stream_name, ShardCount=1)
-    # Wait until stream exists, default is 10 * 18 seconds.
-    client.get_waiter('stream_exists').wait(StreamName=stream_name)
-    for i in range(10):
-      data = "D" + str(i)
-      client.put_record(
-          StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))
-
-    stream = array_ops.placeholder(dtypes.string, shape=[])
-    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
-    batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
-    repeat_dataset = kinesis_dataset_ops.KinesisDataset(
-        stream, read_indefinitely=False).repeat(num_epochs)
-    batch_dataset = repeat_dataset.batch(batch_size)
-
-    iterator = iterator_ops.Iterator.from_structure(
-        dataset_ops.get_legacy_output_types(batch_dataset))
-    init_op = iterator.make_initializer(repeat_dataset)
-    init_batch_op = iterator.make_initializer(batch_dataset)
-    get_next = iterator.get_next()
-
-    with self.cached_session() as sess:
-      # Basic test: read from shard 0 of stream 1.
-      sess.run(init_op, feed_dict={stream: stream_name, num_epochs: 1})
-      for i in range(10):
-        self.assertEqual("D" + str(i), sess.run(get_next))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
-
-    client.delete_stream(StreamName=stream_name)
-    # Wait until stream deleted, default is 10 * 18 seconds.
-    client.get_waiter('stream_not_exists').wait(StreamName=stream_name)
-
-  def testKinesisDatasetTwoShards(self):
-    client = boto3.client('kinesis', region_name='us-east-1')
-
-    # Setup the Kinesis with 2 shards.
-    stream_name = "tf_kinesis_test_2"
-    client.create_stream(StreamName=stream_name, ShardCount=2)
-    # Wait until stream exists, default is 10 * 18 seconds.
-    client.get_waiter('stream_exists').wait(StreamName=stream_name)
-
-    for i in range(10):
-      data = "D" + str(i)
-      client.put_record(
-          StreamName=stream_name, Data=data, PartitionKey="TensorFlow" + str(i))
-    response = client.describe_stream(StreamName=stream_name)
-    shard_id_0 = response["StreamDescription"]["Shards"][0]["ShardId"]
-    shard_id_1 = response["StreamDescription"]["Shards"][1]["ShardId"]
-
-    stream = array_ops.placeholder(dtypes.string, shape=[])
-    shard = array_ops.placeholder(dtypes.string, shape=[])
-    num_epochs = array_ops.placeholder(dtypes.int64, shape=[])
-    batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
-    repeat_dataset = kinesis_dataset_ops.KinesisDataset(
-        stream, shard, read_indefinitely=False).repeat(num_epochs)
-    batch_dataset = repeat_dataset.batch(batch_size)
-
-    iterator = iterator_ops.Iterator.from_structure(
-        dataset_ops.get_legacy_output_types(batch_dataset))
-    init_op = iterator.make_initializer(repeat_dataset)
-    init_batch_op = iterator.make_initializer(batch_dataset)
-    get_next = iterator.get_next()
-
-    data = []
-    with self.cached_session() as sess:
-      # Basic test: read from shard 0 of stream 2.
-      sess.run(
-          init_op, feed_dict={
-              stream: stream_name, shard: shard_id_0, num_epochs: 1})
-      with self.assertRaises(errors.OutOfRangeError):
-        # Use range(11) to guarantee the OutOfRangeError.
-        for i in range(11):
-          data.append(sess.run(get_next))
-
-      # Basic test: read from shard 1 of stream 2.
-      sess.run(
-          init_op, feed_dict={
-              stream: stream_name, shard: shard_id_1, num_epochs: 1})
-      with self.assertRaises(errors.OutOfRangeError):
-        # Use range(11) to guarantee the OutOfRangeError.
-        for i in range(11):
-          data.append(sess.run(get_next))
-
-      data.sort()
-      self.assertEqual(data, ["D" + str(i) for i in range(10)])
-
-    client.delete_stream(StreamName=stream_name)
-    # Wait until stream deleted, default is 10 * 18 seconds.
-    client.get_waiter('stream_not_exists').wait(StreamName=stream_name)
-
-
-if __name__ == "__main__":
-  test.main()
--- a/tensorflow/contrib/kinesis/python/ops/kinesis_dataset_ops.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Kinesis Dataset."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.kinesis.python.ops import gen_dataset_ops
-from tensorflow.contrib.kinesis.python.ops import kinesis_op_loader  # pylint: disable=unused-import
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_spec
-from tensorflow.python.util import deprecation
-
-
-class KinesisDataset(dataset_ops.DatasetSource):
-  """A Kinesis Dataset that consumes the message.
-
-  Kinesis is a managed service provided by AWS for data streaming.
-  This dataset reads messages from Kinesis with each message presented
-  as a `tf.string`.
-
-  For example, we can construct and use the KinesisDataset as follows:
-  ```python
-  tf.compat.v1.enable_eager_execution()
-
-  dataset = tf.contrib.kinesis.KinesisDataset(
-      "kinesis_stream_name", read_indefinitely=False)
-  for element in dataset:
-    print(element)
-  ```
-
-  Since Kinesis is a data streaming service, data may not be available
-  at the time it is being read. The argument `read_indefinitely` is
-  used to control the behavior in this situation. If `read_indefinitely`
-  is `True`, then `KinesisDataset` will keep retrying to retrieve data
-  from the stream. If `read_indefinitely` is `False`, an `OutOfRangeError`
-  is returned immediately instead.
-  """
-
-  @deprecation.deprecated(
-      None,
-      "tf.contrib.kinesis will be removed in 2.0, the support for Kinesis "
-      "will continue to be provided through the tensorflow/io GitHub project.")
-  def __init__(self,
-               stream,
-               shard="",
-               read_indefinitely=True,
-               interval=100000):
-    """Create a KinesisDataset.
-
-    Args:
-      stream: A `tf.string` tensor containing the name of the stream.
-      shard: A `tf.string` tensor containing the id of the shard.
-      read_indefinitely: If `True`, the Kinesis dataset will keep retry
-        again on `EOF` after the `interval` period. If `False`, then
-        the dataset will stop on `EOF`. The default value is `True`.
-      interval: The interval for the Kinesis Client to wait before
-        it tries to get records again (in millisecond).
-    """
-    self._stream = ops.convert_to_tensor(
-        stream, dtype=dtypes.string, name="stream")
-    self._shard = ops.convert_to_tensor(
-        shard, dtype=dtypes.string, name="shard")
-    self._read_indefinitely = ops.convert_to_tensor(
-        read_indefinitely, dtype=dtypes.bool, name="read_indefinitely")
-    self._interval = ops.convert_to_tensor(
-        interval, dtype=dtypes.int64, name="interval")
-    super(KinesisDataset, self).__init__(self._as_variant_tensor())
-
-  def _as_variant_tensor(self):
-    return gen_dataset_ops.kinesis_dataset(
-        self._stream, self._shard, self._read_indefinitely, self._interval)
-
-  @property
-  def element_spec(self):
-    return tensor_spec.TensorSpec([], dtypes.string)
--- a/tensorflow/contrib/kinesis/python/ops/kinesis_op_loader.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Python helper for loading kinesis ops and kernels."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.util import loader
-from tensorflow.python.platform import resource_loader
-
-_dataset_ops = loader.load_op_library(
-    resource_loader.get_path_to_datafile("../../_dataset_ops.so"))
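For reference, the loader deleted above is a thin wrapper: it resolves the `_dataset_ops.so` shared object relative to the Python package and registers the ops it contains. The same pattern outside contrib, as a minimal sketch (the `.so` path is a placeholder):

```python
import tensorflow as tf

# Load a custom-op shared object; the returned module exposes Python wrappers
# for every op the library registers (here that would include KinesisDataset).
_lib = tf.load_op_library("/path/to/_dataset_ops.so")  # placeholder path
```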