Add EncodeBase64 and DecodeBase64 ops.
Change: 130534089
This commit is contained in:
parent
20f029cc37
commit
60662676e4
@ -887,6 +887,7 @@ cc_library(
|
||||
"lib/random/random.h",
|
||||
"lib/random/random_distributions.h",
|
||||
"lib/random/weighted_picker.h",
|
||||
"lib/strings/base64.h",
|
||||
"lib/strings/ordered_code.h",
|
||||
"lib/strings/proto_text_util.h",
|
||||
"lib/strings/regexp.h",
|
||||
@ -1320,6 +1321,7 @@ tf_cc_tests(
|
||||
"lib/random/random_distributions_test.cc",
|
||||
"lib/random/random_test.cc",
|
||||
"lib/random/simple_philox_test.cc",
|
||||
"lib/strings/base64_test.cc",
|
||||
"lib/strings/numbers_test.cc",
|
||||
"lib/strings/scanner_test.cc",
|
||||
"lib/strings/str_util_test.cc",
|
||||
|
@ -1706,6 +1706,7 @@ tf_kernel_libraries(
|
||||
"string_join_op",
|
||||
"string_split_op",
|
||||
"as_string_op",
|
||||
"base64_ops",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:framework",
|
||||
|
77
tensorflow/core/kernels/base64_ops.cc
Normal file
77
tensorflow/core/kernels/base64_ops.cc
Normal file
@ -0,0 +1,77 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/string_ops.cc.
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/strings/base64.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class EncodeBase64Op : public OpKernel {
|
||||
public:
|
||||
explicit EncodeBase64Op(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("pad", &pad_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input_tensor = context->input(0);
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
||||
&output_tensor));
|
||||
|
||||
auto input = input_tensor.flat<string>();
|
||||
auto output = output_tensor->flat<string>();
|
||||
|
||||
for (int64 i = 0; i < input.dimension(0); ++i) {
|
||||
OP_REQUIRES_OK(context, Base64Encode(input(i), pad_, &output(i)));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
bool pad_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("EncodeBase64").Device(DEVICE_CPU),
|
||||
EncodeBase64Op);
|
||||
|
||||
class DecodeBase64Op : public OpKernel {
|
||||
public:
|
||||
using OpKernel::OpKernel;
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input_tensor = context->input(0);
|
||||
Tensor* output_tensor = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
||||
&output_tensor));
|
||||
|
||||
auto input = input_tensor.flat<string>();
|
||||
auto output = output_tensor->flat<string>();
|
||||
|
||||
for (int64 i = 0; i < input.dimension(0); ++i) {
|
||||
OP_REQUIRES_OK(context, Base64Decode(input(i), &output(i)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DecodeBase64").Device(DEVICE_CPU),
|
||||
DecodeBase64Op);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
217
tensorflow/core/lib/strings/base64.cc
Normal file
217
tensorflow/core/lib/strings/base64.cc
Normal file
@ -0,0 +1,217 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/lib/strings/base64.h"
|
||||
|
||||
#include <memory>
#include <new>
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
// Decoding table for the web-safe base64 alphabet, indexed by 7-bit ASCII code.
// 'A'-'Z' map to 0x00-0x19, 'a'-'z' to 0x1A-0x33, '0'-'9' to 0x34-0x3D,
// '-' to 0x3E and '_' to 0x3F. Every other entry is -1, meaning "invalid".
//
// This array must have signed type: Convert() relies on invalid entries being
// negative so that sign extension sets the high bits of its return value.
|
||||
// clang-format off
|
||||
constexpr int8 kBase64Bytes[128] = {
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, 0x3E, -1, -1,
|
||||
0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, -1, -1,
|
||||
-1, -1, -1, -1, -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
|
||||
0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12,
|
||||
0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, -1, -1, -1, -1, 0x3F,
|
||||
-1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24,
|
||||
0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
|
||||
0x31, 0x32, 0x33, -1, -1, -1, -1, -1};
|
||||
// clang-format on
|
||||
|
||||
// Encoding alphabet: the character at index v encodes the 6-bit value v.
// The array size is 65 because the 64 characters are followed by the string
// literal's terminating NUL.
constexpr char kBase64UrlSafeChars[65] =
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||
|
||||
// Padding character: emitted by the encoder when padding is requested, and
// recognized by the decoder in the final four-character group of the input.
constexpr char kPadChar = '=';
|
||||
|
||||
// Converts a char (8 bits) into a 6-bit value for decoding. If the input char
|
||||
// is invalid for base64 encoding, the return value has at least its upper 25
|
||||
// bits set.
|
||||
inline uint32 Convert(char x) {
|
||||
// If x < 128, then we look up x in the table. If x is valid, then the table
|
||||
// will have a value <= 0x3F, otherwise the table will have -1. If x >= 128,
|
||||
// we still do some table lookup, but the value is ignored since we explicitly
|
||||
// set the high bit of y to 1. Either way, y is negative (high bit set) in
|
||||
// case of error.
|
||||
const int8 y = kBase64Bytes[x & 0x7F] | (x & 0x80);
|
||||
// Casting from int8 to int32 preserves sign by sign extension. If y was
|
||||
// negative, at least its 25 high bits of the return value are set.
|
||||
const int32 z = static_cast<int32>(y);
|
||||
return static_cast<uint32>(z);
|
||||
}
|
||||
|
||||
// Decodes the two base64 characters codes[0..1] into a single byte written to
// *result. The low 4 bits of the second character's value are discarded.
// Returns InvalidArgument if either character is not valid base64.
Status DecodeOneChar(const char* codes, char* result) {
  const uint32 high = Convert(codes[0]) << 2;
  const uint32 low = Convert(codes[1]) >> 4;
  const uint32 bits = high | low;
  // Convert() sets the upper 25 bits of its result for an invalid character,
  // so after the shifts above any invalid input leaves bits in the top byte.
  if (TF_PREDICT_FALSE((bits & 0xFF000000) != 0)) {
    return errors::InvalidArgument("Invalid character found in base64.");
  }
  *result = static_cast<char>(bits);
  return Status::OK();
}
|
||||
|
||||
// Decodes the three base64 characters codes[0..2] into two bytes written to
// result[0..1]. The low 2 bits of the third character's value are discarded.
// Returns InvalidArgument if any of the characters is not valid base64.
Status DecodeTwoChars(const char* codes, char* result) {
  uint32 bits = Convert(codes[0]) << 10;
  bits |= Convert(codes[1]) << 4;
  bits |= Convert(codes[2]) >> 2;
  // An invalid character makes Convert() return a value with its upper 25
  // bits set, which survives the shifts above as bits in the top byte.
  if (TF_PREDICT_FALSE((bits & 0xFF000000) != 0)) {
    return errors::InvalidArgument("Invalid character found in base64.");
  }
  result[0] = static_cast<char>(bits >> 8);
  result[1] = static_cast<char>(bits);
  return Status::OK();
}
|
||||
|
||||
// Decodes the four base64 characters codes[0..3] into three bytes written to
// result[0..2]. Returns InvalidArgument if any of the characters is not valid
// base64.
Status DecodeThreeChars(const char* codes, char* result) {
  uint32 bits = Convert(codes[0]) << 18;
  bits |= Convert(codes[1]) << 12;
  bits |= Convert(codes[2]) << 6;
  bits |= Convert(codes[3]);
  // Four valid 6-bit values occupy the low 24 bits only; an invalid character
  // makes Convert() set its upper 25 bits, which shows up in the top byte.
  if (TF_PREDICT_FALSE((bits & 0xFF000000) != 0)) {
    return errors::InvalidArgument("Invalid character found in base64.");
  }
  result[0] = static_cast<char>(bits >> 16);
  result[1] = static_cast<char>(bits >> 8);
  result[2] = static_cast<char>(bits);
  return Status::OK();
}
|
||||
} // namespace
|
||||
|
||||
// Decodes web-safe base64 `data` and stores the result in `*decoded`.
//
// The input may be unpadded, or padded with '=' so that its length is a
// multiple of four. Padding is recognized only in the final four-character
// group; a '=' anywhere else is treated as an invalid character.
//
// The decoded bytes are accumulated in a temporary buffer and assigned to
// `*decoded` only after the whole input has been decoded successfully.
//
// Returns:
//   Internal          if `decoded` is nullptr.
//   ResourceExhausted if the temporary buffer cannot be allocated.
//   InvalidArgument   if the input contains an invalid character or its
//                     length is 1 modulo 4.
Status Base64Decode(StringPiece data, string* decoded) {
  if (decoded == nullptr) {
    return errors::Internal("'decoded' cannot be nullptr.");
  }

  if (data.empty()) {
    decoded->clear();
    return Status::OK();
  }

  // max_decoded_size may overestimate by up to 3 bytes.
  const size_t max_decoded_size = 3 * (data.size() / 4) + 3;
  // Plain new[] reports failure by throwing and never returns nullptr, so the
  // nothrow form is needed for the check below to be reachable.
  std::unique_ptr<char[]> buffer(new (std::nothrow) char[max_decoded_size]);
  char* current = buffer.get();
  if (current == nullptr) {
    return errors::ResourceExhausted(
        "Failed to allocate buffer for decoded string.");
  }

  const char* b64 = data.data();
  const char* end = data.data() + data.size();

  // Decode every full group except the last one; the final 1-4 characters
  // are handled below because they may carry padding or be a short group.
  while (end - b64 > 4) {
    TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current));
    b64 += 4;
    current += 3;
  }

  if (end - b64 == 4) {
    // The data length is a multiple of 4. Check for padding.
    // Base64 cannot have more than 2 paddings.
    if (b64[2] == kPadChar && b64[3] == kPadChar) {
      end -= 2;
    }
    if (b64[2] != kPadChar && b64[3] == kPadChar) {
      end -= 1;
    }
  }

  // Decode the final group according to how many characters remain once any
  // trailing padding has been stripped.
  switch (end - b64) {
    case 4:
      TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current));
      current += 3;
      break;
    case 3:
      TF_RETURN_IF_ERROR(DecodeTwoChars(b64, current));
      current += 2;
      break;
    case 2:
      TF_RETURN_IF_ERROR(DecodeOneChar(b64, current));
      current += 1;
      break;
    default:  // case 1
      // We may check this condition early by checking data.size() % 4 == 1.
      return errors::InvalidArgument(
          "Base64 string length cannot be 1 modulo 4.");
  }

  decoded->assign(buffer.get(), current - buffer.get());
  return Status::OK();
}
|
||||
|
||||
// Convenience overload: encodes `source` as web-safe base64 without '='
// padding by forwarding to the three-argument overload.
Status Base64Encode(StringPiece source, string* encoded) {
  constexpr bool kWithPadding = false;
  return Base64Encode(source, kWithPadding, encoded);
}
|
||||
|
||||
// Encodes `source` as web-safe base64 (alphabet A-Z, a-z, 0-9, '-', '_') and
// stores the result in `*encoded`.
//
// When `with_padding` is true, the output is padded with '=' so that a
// trailing group of one or two input bytes still produces four characters.
// When it is false, such a group produces two or three characters.
//
// The encoded characters are accumulated in a temporary buffer and assigned
// to `*encoded` once encoding is complete.
//
// Returns:
//   Internal          if `encoded` is nullptr.
//   ResourceExhausted if the temporary buffer cannot be allocated.
Status Base64Encode(StringPiece source, bool with_padding, string* encoded) {
  const char* const base64_chars = kBase64UrlSafeChars;
  if (encoded == nullptr) {
    return errors::Internal("'encoded' cannot be nullptr.");
  }

  // max_encoded_size may overestimate by up to 4 bytes.
  const size_t max_encoded_size = 4 * (source.size() / 3) + 4;
  // Plain new[] reports failure by throwing and never returns nullptr, so the
  // nothrow form is needed for the check below to be reachable.
  std::unique_ptr<char[]> buffer(new (std::nothrow) char[max_encoded_size]);
  char* current = buffer.get();
  if (current == nullptr) {
    return errors::ResourceExhausted(
        "Failed to allocate buffer for encoded string.");
  }

  const char* data = source.data();
  const char* const end = source.data() + source.size();

  // Encode each block: three input bytes become four output characters. Every
  // index is masked to 6 bits, so the sign of `char` does not affect it.
  while (end - data >= 3) {
    *current++ = base64_chars[(data[0] >> 2) & 0x3F];
    *current++ =
        base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
    *current++ =
        base64_chars[((data[1] & 0x0F) << 2) | ((data[2] >> 6) & 0x03)];
    *current++ = base64_chars[data[2] & 0x3F];

    data += 3;
  }

  // Take care of the tail: one or two leftover bytes whose unused low bits
  // are zero-filled in the last emitted character.
  if (end - data == 2) {
    *current++ = base64_chars[(data[0] >> 2) & 0x3F];
    *current++ =
        base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)];
    *current++ = base64_chars[(data[1] & 0x0F) << 2];
    if (with_padding) {
      *current++ = kPadChar;
    }
  } else if (end - data == 1) {
    *current++ = base64_chars[(data[0] >> 2) & 0x3F];
    *current++ = base64_chars[(data[0] & 0x03) << 4];
    if (with_padding) {
      *current++ = kPadChar;
      *current++ = kPadChar;
    }
  }

  encoded->assign(buffer.get(), current - buffer.get());
  return Status::OK();
}
|
||||
|
||||
} // namespace tensorflow
|
@ -13,24 +13,25 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_B64_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_B64_H_
|
||||
#ifndef TENSORFLOW_LIB_STRINGS_B64_H_
|
||||
#define TENSORFLOW_LIB_STRINGS_B64_H_
|
||||
|
||||
#include <string>
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/// \brief Converts data into base64 encoding.
|
||||
/// \brief Converts data into web-safe base64 encoding.
|
||||
///
|
||||
/// See https://en.wikipedia.org/wiki/Base64
|
||||
Status Base64Encode(StringPiece data, string* encoded);
|
||||
Status Base64Encode(StringPiece data, bool with_padding, string* encoded);
|
||||
Status Base64Encode(StringPiece data, string* encoded); // with_padding=false.
|
||||
|
||||
/// \brief Converts data from base64 encoding.
|
||||
/// \brief Converts data from web-safe base64 encoding.
|
||||
///
|
||||
/// See https://en.wikipedia.org/wiki/Base64
|
||||
Status Base64Decode(StringPiece data, string* decoded);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_B64_H_
|
||||
#endif // TENSORFLOW_LIB_STRINGS_B64_H_
|
@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include "tensorflow/core/lib/strings/base64.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
@ -247,4 +247,38 @@ shape: a length-2 vector of int64 representing the shape of the sparse
|
||||
of tokens in a single input entry.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("EncodeBase64")
|
||||
.Input("input: string")
|
||||
.Output("output: string")
|
||||
.Attr("pad: bool = false")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Encode strings into web-safe base64 format.
|
||||
|
||||
Refer to the following article for more information on base64 format:
|
||||
en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the
|
||||
end so that the encoded has length multiple of 4. See Padding section of the
|
||||
link above.
|
||||
|
||||
Web-safe means that the encoder uses - and _ instead of + and /.
|
||||
|
||||
input: Strings to be encoded.
|
||||
output: Input strings encoded in base64.
|
||||
pad: Bool whether padding is applied at the ends.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("DecodeBase64")
|
||||
.Input("input: string")
|
||||
.Output("output: string")
|
||||
.SetShapeFn(shape_inference::UnchangedShape)
|
||||
.Doc(R"doc(
|
||||
Decode web-safe base64-encoded strings.
|
||||
|
||||
Input may or may not have padding at the end. See EncodeBase64 for padding.
|
||||
Web-safe means that input must use - and _ instead of + and /.
|
||||
|
||||
input: Base64 strings to decode.
|
||||
output: Decoded strings.
|
||||
)doc");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -83,10 +83,10 @@ cc_library(
|
||||
"google_auth_provider.h",
|
||||
],
|
||||
deps = [
|
||||
":base64",
|
||||
":http_request",
|
||||
":oauth_client",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@jsoncpp_git//:jsoncpp",
|
||||
],
|
||||
)
|
||||
@ -100,27 +100,14 @@ cc_library(
|
||||
"oauth_client.h",
|
||||
],
|
||||
deps = [
|
||||
":base64",
|
||||
":http_request",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@boringssl_git//:crypto",
|
||||
"@jsoncpp_git//:jsoncpp",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "base64",
|
||||
srcs = [
|
||||
"base64.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"base64.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "retrying_file_system",
|
||||
srcs = [
|
||||
@ -165,7 +152,6 @@ tf_cc_test(
|
||||
"testdata/service_account_public_key.txt",
|
||||
],
|
||||
deps = [
|
||||
":base64",
|
||||
":http_request_fake",
|
||||
":oauth_client",
|
||||
"//tensorflow/core:lib",
|
||||
@ -176,18 +162,6 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "base64_test",
|
||||
size = "small",
|
||||
deps = [
|
||||
":base64",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "google_auth_provider_test",
|
||||
size = "small",
|
||||
@ -196,7 +170,6 @@ tf_cc_test(
|
||||
"testdata/service_account_credentials.json",
|
||||
],
|
||||
deps = [
|
||||
":base64",
|
||||
":google_auth_provider",
|
||||
":http_request_fake",
|
||||
":oauth_client",
|
||||
|
@ -1,221 +0,0 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include <memory>
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr signed char kBase64Bytes[] = {
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, 0x3E, -1, -1, -1, 0x3F,
|
||||
0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, -1, -1,
|
||||
-1, 0x7F, -1, -1, -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
|
||||
0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12,
|
||||
0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, -1, -1, -1, -1, -1,
|
||||
-1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24,
|
||||
0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
|
||||
0x31, 0x32, 0x33, -1, -1, -1, -1, -1};
|
||||
|
||||
constexpr char kBase64UrlSafeChars[] =
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||
|
||||
constexpr char kPadChar = '=';
|
||||
constexpr char kPadByte = 0x7F;
|
||||
constexpr int kMultilineLineLen = 76;
|
||||
constexpr int kMultilineNumBlocks = kMultilineLineLen / 4;
|
||||
|
||||
Status Base64Encode(StringPiece source, bool multiline, bool with_padding,
|
||||
string *encoded) {
|
||||
if (!encoded) {
|
||||
return errors::FailedPrecondition("'encoded' cannot be nullptr.");
|
||||
}
|
||||
size_t data_size = source.size();
|
||||
const char *data = source.data();
|
||||
const char *base64_chars = kBase64UrlSafeChars;
|
||||
const size_t result_projected_size =
|
||||
4 * ((data_size + 3) / 3) +
|
||||
2 * (multiline ? (data_size / (3 * kMultilineNumBlocks)) : 0) + 1;
|
||||
size_t num_blocks = 0;
|
||||
size_t i = 0;
|
||||
std::unique_ptr<char[]> result(new char[result_projected_size]);
|
||||
char *current = result.get();
|
||||
|
||||
/* Encode each block. */
|
||||
while (data_size >= 3) {
|
||||
*current++ = base64_chars[(data[i] >> 2) & 0x3F];
|
||||
*current++ =
|
||||
base64_chars[((data[i] & 0x03) << 4) | ((data[i + 1] >> 4) & 0x0F)];
|
||||
*current++ =
|
||||
base64_chars[((data[i + 1] & 0x0F) << 2) | ((data[i + 2] >> 6) & 0x03)];
|
||||
*current++ = base64_chars[data[i + 2] & 0x3F];
|
||||
|
||||
data_size -= 3;
|
||||
i += 3;
|
||||
if (multiline && (++num_blocks == kMultilineNumBlocks)) {
|
||||
*current++ = '\r';
|
||||
*current++ = '\n';
|
||||
num_blocks = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/* Take care of the tail. */
|
||||
if (data_size == 2) {
|
||||
*current++ = base64_chars[(data[i] >> 2) & 0x3F];
|
||||
*current++ =
|
||||
base64_chars[((data[i] & 0x03) << 4) | ((data[i + 1] >> 4) & 0x0F)];
|
||||
*current++ = base64_chars[(data[i + 1] & 0x0F) << 2];
|
||||
if (with_padding) {
|
||||
*current++ = kPadChar;
|
||||
}
|
||||
} else if (data_size == 1) {
|
||||
*current++ = base64_chars[(data[i] >> 2) & 0x3F];
|
||||
*current++ = base64_chars[(data[i] & 0x03) << 4];
|
||||
if (with_padding) {
|
||||
*current++ = kPadChar;
|
||||
*current++ = kPadChar;
|
||||
}
|
||||
}
|
||||
|
||||
if (current < result.get() ||
|
||||
current >= result.get() + result_projected_size) {
|
||||
return errors::Internal("Unexpected encoding bug.");
|
||||
}
|
||||
*current++ = '\0';
|
||||
*encoded = result.get();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DecodeOneChar(const unsigned char *codes, unsigned char *result,
|
||||
size_t *result_offset) {
|
||||
const uint32_t packed = ((uint32_t)codes[0] << 2) | ((uint32_t)codes[1] >> 4);
|
||||
result[(*result_offset)++] = (unsigned char)packed;
|
||||
}
|
||||
|
||||
void DecodeTwoChars(const unsigned char *codes, unsigned char *result,
|
||||
size_t *result_offset) {
|
||||
const uint32_t packed = ((uint32_t)codes[0] << 10) |
|
||||
((uint32_t)codes[1] << 4) | ((uint32_t)codes[2] >> 2);
|
||||
result[(*result_offset)++] = (unsigned char)(packed >> 8);
|
||||
result[(*result_offset)++] = (unsigned char)(packed);
|
||||
}
|
||||
|
||||
Status DecodeGroup(const unsigned char *codes, size_t num_codes,
|
||||
unsigned char *result, size_t *result_offset) {
|
||||
if (num_codes > 4) {
|
||||
return errors::FailedPrecondition("Expected 4 or fewer codes.");
|
||||
}
|
||||
|
||||
/* Short end groups that may not have padding. */
|
||||
if (num_codes == 1) {
|
||||
return errors::FailedPrecondition(
|
||||
"Invalid group. Must be at least 2 bytes.");
|
||||
}
|
||||
if (num_codes == 2) {
|
||||
DecodeOneChar(codes, result, result_offset);
|
||||
return Status::OK();
|
||||
}
|
||||
if (num_codes == 3) {
|
||||
DecodeTwoChars(codes, result, result_offset);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* Regular 4 byte groups with padding or not. */
|
||||
if (num_codes != 4) {
|
||||
return errors::FailedPrecondition("Expected exactly 4 codes.");
|
||||
}
|
||||
if (codes[0] == kPadByte || codes[1] == kPadByte) {
|
||||
return errors::FailedPrecondition("Invalid padding detected.");
|
||||
}
|
||||
if (codes[2] == kPadByte) {
|
||||
if (codes[3] == kPadByte) {
|
||||
DecodeOneChar(codes, result, result_offset);
|
||||
} else {
|
||||
return errors::FailedPrecondition("Invalid padding detected.");
|
||||
}
|
||||
} else if (codes[3] == kPadByte) {
|
||||
DecodeTwoChars(codes, result, result_offset);
|
||||
} else {
|
||||
/* No padding. */
|
||||
const uint32_t packed = ((uint32_t)codes[0] << 18) |
|
||||
((uint32_t)codes[1] << 12) |
|
||||
((uint32_t)codes[2] << 6) | codes[3];
|
||||
result[(*result_offset)++] = (unsigned char)(packed >> 16);
|
||||
result[(*result_offset)++] = (unsigned char)(packed >> 8);
|
||||
result[(*result_offset)++] = (unsigned char)(packed);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status Base64Encode(StringPiece source, string *encoded) {
|
||||
return Base64Encode(source, false, false, encoded);
|
||||
}
|
||||
|
||||
Status Base64Decode(StringPiece data, string *decoded) {
|
||||
if (!decoded) {
|
||||
return errors::FailedPrecondition("'decoded' cannot be nullptr.");
|
||||
}
|
||||
std::unique_ptr<unsigned char[]> result(new unsigned char[data.size()]);
|
||||
unsigned char *current = result.get();
|
||||
size_t result_size = 0;
|
||||
unsigned char codes[4];
|
||||
size_t num_codes = 0;
|
||||
|
||||
const char *b64 = data.data();
|
||||
size_t b64_len = data.size();
|
||||
while (b64_len--) {
|
||||
unsigned char c = (unsigned char)(*b64++);
|
||||
signed char code;
|
||||
if (c >= sizeof(kBase64Bytes)) continue;
|
||||
if (c == '+' || c == '/') {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Invalid character for url safe base64 ", c));
|
||||
}
|
||||
if (c == '-') {
|
||||
c = '+';
|
||||
} else if (c == '_') {
|
||||
c = '/';
|
||||
}
|
||||
code = kBase64Bytes[c];
|
||||
if (code == -1) {
|
||||
if (c != '\r' && c != '\n') {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Invalid character ", c));
|
||||
}
|
||||
} else {
|
||||
codes[num_codes++] = (unsigned char)code;
|
||||
if (num_codes == 4) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
DecodeGroup(codes, num_codes, current, &result_size));
|
||||
num_codes = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (num_codes != 0) {
|
||||
TF_RETURN_IF_ERROR(DecodeGroup(codes, num_codes, current, &result_size));
|
||||
}
|
||||
*decoded = string(reinterpret_cast<char *>(result.get()), result_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "include/json/json.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include "tensorflow/core/lib/strings/base64.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/pem.h>
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include "tensorflow/core/lib/strings/base64.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
|
@ -20,8 +20,8 @@ limitations under the License.
|
||||
#include <openssl/pem.h>
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/base64.h"
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request_fake.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
@ -22,6 +22,7 @@ py_tests(
|
||||
"as_string_op_test.py",
|
||||
"attention_ops_test.py",
|
||||
"barrier_ops_test.py",
|
||||
"base64_ops_test.py",
|
||||
"bcast_ops_test.py",
|
||||
"benchmark_test.py",
|
||||
"candidate_sampler_ops_test.py",
|
||||
|
155
tensorflow/python/kernel_tests/base64_ops_test.py
Normal file
155
tensorflow/python/kernel_tests/base64_ops_test.py
Normal file
@ -0,0 +1,155 @@
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for EncodeBase64 and DecodeBase64."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import base64
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class Base64OpsTest(test_util.TensorFlowTestCase):
  """Round-trip and invalid-input tests for encode_base64 / decode_base64."""

  def setUp(self):
    # One placeholder feeds two encode->decode chains: unpadded (_f) and
    # padded (_t).
    self._msg = array_ops.placeholder(dtype=dtypes.string)
    self._encoded_f = string_ops.encode_base64(self._msg, pad=False)
    self._decoded_f = string_ops.decode_base64(self._encoded_f)
    self._encoded_t = string_ops.encode_base64(self._msg, pad=True)
    self._decoded_t = string_ops.decode_base64(self._encoded_t)

  def _RemovePad(self, msg, base64_msg):
    """Strips the '=' padding implied by len(msg) from base64_msg."""
    # len(msg) % 3 == 1 -> two pad chars, == 2 -> one pad char, == 0 -> none.
    pad_len = (3 - len(msg) % 3) % 3
    return base64_msg[:len(base64_msg) - pad_len]

  def _RunTest(self, msg, pad):
    """Encodes and decodes msg with the ops and compares against Python base64.

    Args:
      msg: a bytes object, or a list/tuple of bytes objects.
      pad: whether the padded or the unpadded encode op is exercised.
    """
    if pad:
      fetches = [self._encoded_t, self._decoded_t]
    else:
      fetches = [self._encoded_f, self._decoded_f]

    with self.test_session() as sess:
      encoded, decoded = sess.run(fetches, feed_dict={self._msg: msg})

    # Treat a single message as a one-element batch.
    if not isinstance(msg, (list, tuple)):
      msg, encoded, decoded = [msg], [encoded], [decoded]

    expected = [base64.urlsafe_b64encode(m) for m in msg]
    if not pad:
      expected = [self._RemovePad(m, b) for m, b in zip(msg, expected)]

    for i, original in enumerate(msg):
      self.assertEqual(expected[i], encoded[i])
      self.assertEqual(original, decoded[i])

  def testWithPythonBase64(self):
    for pad in (False, True):
      self._RunTest(b"", pad=pad)

      for _ in range(100):
        size = np.random.randint(1024 * 1024)
        self._RunTest(np.random.bytes(size), pad=pad)

  def testShape(self):
    for pad in (False, True):
      for _ in range(10):
        count = np.random.randint(10)
        batch = [np.random.bytes(np.random.randint(20)) for _ in range(count)]
        self._RunTest(batch, pad=pad)

      # Zero-element, non-trivial shapes.
      for _ in range(10):
        cols = np.random.randint(10)
        msg = np.empty((0, cols), dtype=bytes)
        encoded = string_ops.encode_base64(msg, pad=pad)
        decoded = string_ops.decode_base64(encoded)

        with self.test_session() as sess:
          encoded_value, decoded_value = sess.run([encoded, decoded])

        self.assertEqual(encoded_value.shape, msg.shape)
        self.assertEqual(decoded_value.shape, msg.shape)

  def testInvalidInput(self):
    def try_decode(enc):
      # Feeds the encoded tensor directly, bypassing the encode op.
      self._decoded_f.eval(feed_dict={self._encoded_f: enc})

    def expect_rejected(enc):
      with self.assertRaises(errors.InvalidArgumentError):
        try_decode(enc)

    with self.test_session():
      # Invalid length.
      enc = base64.urlsafe_b64encode(np.random.bytes(99))
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "1 modulo 4"):
        try_decode(enc + b"a")

      # Invalid char used in encoding: a char outside the alphabet, a byte
      # outside the ASCII range, and the two chars that are not url-safe.
      msg = np.random.bytes(34)
      enc = base64.urlsafe_b64encode(msg)
      for i in range(len(msg)):
        for bad in (b"?", b"\x80", b"+", b"/"):
          expect_rejected(enc[:i] + bad + enc[(i + 1):])

      # Partial padding.
      msg = np.random.bytes(34)
      enc = base64.urlsafe_b64encode(msg)
      # enc contains == at the end. Partial padding is not allowed.
      expect_rejected(enc[:-1])

      # Unnecessary padding.
      msg = np.random.bytes(33)
      enc = base64.urlsafe_b64encode(msg)
      for extra in (b"==", b"===", b"===="):
        expect_rejected(enc + extra)

      # Padding in the middle. (The previous implementation accepted this as
      # long as the padding char location was 2 or 3 (mod 4).)
      msg = np.random.bytes(33)
      enc = base64.urlsafe_b64encode(msg)
      for i in range(len(msg) - 1):
        expect_rejected(enc[:i] + b"=" + enc[(i + 1):])
      for i in range(len(msg) - 2):
        expect_rejected(enc[:i] + b"==" + enc[(i + 2):])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
@ -37,6 +37,8 @@ string tensor.
|
||||
## Conversion
|
||||
|
||||
@@as_string
|
||||
@@encode_base64
|
||||
@@decode_base64
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
@ -118,11 +120,15 @@ ops.NoGradient("ReduceJoin")
|
||||
ops.NoGradient("StringJoin")
|
||||
ops.NoGradient("StringSplit")
|
||||
ops.NoGradient("AsString")
|
||||
ops.NoGradient("EncodeBase64")
|
||||
ops.NoGradient("DecodeBase64")
|
||||
|
||||
ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape)
|
||||
ops.RegisterShape("StringToHashBucketFast")(common_shapes.unchanged_shape)
|
||||
ops.RegisterShape("StringToHashBucketStrong")(common_shapes.unchanged_shape)
|
||||
ops.RegisterShape("AsString")(common_shapes.unchanged_shape)
|
||||
ops.RegisterShape("EncodeBase64")(common_shapes.unchanged_shape)
|
||||
ops.RegisterShape("DecodeBase64")(common_shapes.unchanged_shape)
|
||||
|
||||
|
||||
@ops.RegisterShape("ReduceJoin")
|
||||
|
Loading…
Reference in New Issue
Block a user