diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a9ac82ae1c2..6f53d5f9127 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -887,6 +887,7 @@ cc_library( "lib/random/random.h", "lib/random/random_distributions.h", "lib/random/weighted_picker.h", + "lib/strings/base64.h", "lib/strings/ordered_code.h", "lib/strings/proto_text_util.h", "lib/strings/regexp.h", @@ -1320,6 +1321,7 @@ tf_cc_tests( "lib/random/random_distributions_test.cc", "lib/random/random_test.cc", "lib/random/simple_philox_test.cc", + "lib/strings/base64_test.cc", "lib/strings/numbers_test.cc", "lib/strings/scanner_test.cc", "lib/strings/str_util_test.cc", diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index a270fdefa88..04316fb0855 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1706,6 +1706,7 @@ tf_kernel_libraries( "string_join_op", "string_split_op", "as_string_op", + "base64_ops", ], deps = [ "//tensorflow/core:framework", diff --git a/tensorflow/core/kernels/base64_ops.cc b/tensorflow/core/kernels/base64_ops.cc new file mode 100644 index 00000000000..74e6b39390a --- /dev/null +++ b/tensorflow/core/kernels/base64_ops.cc @@ -0,0 +1,77 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/string_ops.cc. 
+ +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/strings/base64.h" + +namespace tensorflow { +namespace { + +class EncodeBase64Op : public OpKernel { + public: + explicit EncodeBase64Op(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("pad", &pad_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input_tensor = context->input(0); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), + &output_tensor)); + + auto input = input_tensor.flat(); + auto output = output_tensor->flat(); + + for (int64 i = 0; i < input.dimension(0); ++i) { + OP_REQUIRES_OK(context, Base64Encode(input(i), pad_, &output(i))); + } + } + + private: + bool pad_; +}; + +REGISTER_KERNEL_BUILDER(Name("EncodeBase64").Device(DEVICE_CPU), + EncodeBase64Op); + +class DecodeBase64Op : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* context) override { + const Tensor& input_tensor = context->input(0); + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), + &output_tensor)); + + auto input = input_tensor.flat(); + auto output = output_tensor->flat(); + + for (int64 i = 0; i < input.dimension(0); ++i) { + OP_REQUIRES_OK(context, Base64Decode(input(i), &output(i))); + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("DecodeBase64").Device(DEVICE_CPU), + DecodeBase64Op); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/lib/strings/base64.cc b/tensorflow/core/lib/strings/base64.cc new file mode 100644 index 00000000000..5eaa35f2b72 --- /dev/null +++ b/tensorflow/core/lib/strings/base64.cc @@ -0,0 +1,217 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/lib/strings/base64.h" + +#include +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace { +// This array must have signed type. +// clang-format off +constexpr int8 kBase64Bytes[128] = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x3E, -1, -1, + 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, -1, -1, + -1, -1, -1, -1, -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, + 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, + 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, -1, -1, -1, -1, 0x3F, + -1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, + 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, + 0x31, 0x32, 0x33, -1, -1, -1, -1, -1}; +// clang-format on + +constexpr char kBase64UrlSafeChars[65] = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; + +constexpr char kPadChar = '='; + +// Converts a char (8 bits) into a 6-bit value for decoding. If the input char +// is invalid for base64 encoding, the return value has at least its upper 25 +// bits set. +inline uint32 Convert(char x) { + // If x < 128, then we look up x in the table. 
If x is valid, then the table + // will have a value <= 0x3F, otherwise the table will have -1. If x >= 128, + // we still do some table lookup, but the value is ignored since we explicitly + // set the high bit of y to 1. Either way, y is negative (high bit set) in + // case of error. + const int8 y = kBase64Bytes[x & 0x7F] | (x & 0x80); + // Casting from int8 to int32 preserves sign by sign extension. If y was + // negative, the 25 high bits of the return value are set. + const int32 z = static_cast(y); + return static_cast(z); +} + +Status DecodeOneChar(const char* codes, char* result) { + const uint32 packed = (Convert(codes[0]) << 2) | + (Convert(codes[1]) >> 4); + // Convert() return value has upper 25 bits set if input is invalid. + // Therefore `packed` has high bits set iff at least one of the codes is invalid. + if (TF_PREDICT_FALSE((packed & 0xFF000000) != 0)) { + return errors::InvalidArgument("Invalid character found in base64."); + } + *result = static_cast(packed); + return Status::OK(); +} + +Status DecodeTwoChars(const char* codes, char* result) { + const uint32 packed = (Convert(codes[0]) << 10) | + (Convert(codes[1]) << 4) | + (Convert(codes[2]) >> 2); + if (TF_PREDICT_FALSE((packed & 0xFF000000) != 0)) { + return errors::InvalidArgument("Invalid character found in base64."); + } + result[0] = static_cast(packed >> 8); + result[1] = static_cast(packed); + return Status::OK(); +} + +Status DecodeThreeChars(const char* codes, char* result) { + const uint32 packed = (Convert(codes[0]) << 18) | + (Convert(codes[1]) << 12) | + (Convert(codes[2]) << 6) | + (Convert(codes[3])); + if (TF_PREDICT_FALSE((packed & 0xFF000000) != 0)) { + return errors::InvalidArgument("Invalid character found in base64."); + } + result[0] = static_cast(packed >> 16); + result[1] = static_cast(packed >> 8); + result[2] = static_cast(packed); + return Status::OK(); +} +} // namespace + +Status Base64Decode(StringPiece data, string* decoded) { + if (decoded == nullptr) { + 
return errors::Internal("'decoded' cannot be nullptr."); + } + + if (data.empty()) { + decoded->clear(); + return Status::OK(); + } + + // max_decoded_size may overestimate by up to 3 bytes. + const size_t max_decoded_size = 3 * (data.size() / 4) + 3; + std::unique_ptr buffer(new char[max_decoded_size]); + char* current = buffer.get(); + if (current == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate buffer for decoded string."); + } + + const char* b64 = data.data(); + const char* end = data.data() + data.size(); + + while (end - b64 > 4) { + TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current)); + b64 += 4; + current += 3; + } + + if (end - b64 == 4) { + // The data length is a multiple of 4. Check for padding. + // Base64 cannot have more than 2 paddings. + if (b64[2] == kPadChar && b64[3] == kPadChar) { + end -= 2; + } + if (b64[2] != kPadChar && b64[3] == kPadChar) { + end -= 1; + } + } + + switch (end - b64) { + case 4: + TF_RETURN_IF_ERROR(DecodeThreeChars(b64, current)); + current += 3; + break; + case 3: + TF_RETURN_IF_ERROR(DecodeTwoChars(b64, current)); + current += 2; + break; + case 2: + TF_RETURN_IF_ERROR(DecodeOneChar(b64, current)); + current += 1; + break; + default: // case 1 + // We may check this condition early by checking data.size() % 4 == 1. + return errors::InvalidArgument( + "Base64 string length cannot be 1 modulo 4."); + } + + decoded->assign(buffer.get(), current - buffer.get()); + return Status::OK(); +} + +Status Base64Encode(StringPiece source, string* encoded) { + return Base64Encode(source, false, encoded); +} + +Status Base64Encode(StringPiece source, bool with_padding, string* encoded) { + const char* const base64_chars = kBase64UrlSafeChars; + if (encoded == nullptr) { + return errors::Internal("'encoded' cannot be nullptr."); + } + + // max_encoded_size may overestimate by up to 4 bytes. 
+ const size_t max_encoded_size = 4 * (source.size() / 3) + 4; + std::unique_ptr buffer(new char[max_encoded_size]); + char* current = buffer.get(); + if (current == nullptr) { + return errors::ResourceExhausted( + "Failed to allocate buffer for encoded string."); + } + + const char* data = source.data(); + const char* const end = source.data() + source.size(); + + // Encode each block. + while (end - data >= 3) { + *current++ = base64_chars[(data[0] >> 2) & 0x3F]; + *current++ = + base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)]; + *current++ = + base64_chars[((data[1] & 0x0F) << 2) | ((data[2] >> 6) & 0x03)]; + *current++ = base64_chars[data[2] & 0x3F]; + + data += 3; + } + + // Take care of the tail. + if (end - data == 2) { + *current++ = base64_chars[(data[0] >> 2) & 0x3F]; + *current++ = + base64_chars[((data[0] & 0x03) << 4) | ((data[1] >> 4) & 0x0F)]; + *current++ = base64_chars[(data[1] & 0x0F) << 2]; + if (with_padding) { + *current++ = kPadChar; + } + } else if (end - data == 1) { + *current++ = base64_chars[(data[0] >> 2) & 0x3F]; + *current++ = base64_chars[(data[0] & 0x03) << 4]; + if (with_padding) { + *current++ = kPadChar; + *current++ = kPadChar; + } + } + + encoded->assign(buffer.get(), current - buffer.get()); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/base64.h b/tensorflow/core/lib/strings/base64.h similarity index 71% rename from tensorflow/core/platform/cloud/base64.h rename to tensorflow/core/lib/strings/base64.h index 75fe393dcaf..48a7f42b81d 100644 --- a/tensorflow/core/platform/cloud/base64.h +++ b/tensorflow/core/lib/strings/base64.h @@ -13,24 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CORE_PLATFORM_B64_H_ -#define TENSORFLOW_CORE_PLATFORM_B64_H_ +#ifndef TENSORFLOW_LIB_STRINGS_B64_H_ +#define TENSORFLOW_LIB_STRINGS_B64_H_ #include #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -/// \brief Converts data into base64 encoding. +/// \brief Converts data into web-safe base64 encoding. /// /// See https://en.wikipedia.org/wiki/Base64 -Status Base64Encode(StringPiece data, string* encoded); +Status Base64Encode(StringPiece data, bool with_padding, string* encoded); +Status Base64Encode(StringPiece data, string* encoded); // with_padding=false. -/// \brief Converts data from base64 encoding. +/// \brief Converts data from web-safe base64 encoding. /// /// See https://en.wikipedia.org/wiki/Base64 Status Base64Decode(StringPiece data, string* decoded); } // namespace tensorflow -#endif // TENSORFLOW_CORE_PLATFORM_B64_H_ +#endif // TENSORFLOW_LIB_STRINGS_B64_H_ diff --git a/tensorflow/core/platform/cloud/base64_test.cc b/tensorflow/core/lib/strings/base64_test.cc similarity index 95% rename from tensorflow/core/platform/cloud/base64_test.cc rename to tensorflow/core/lib/strings/base64_test.cc index 3c1f07e25b8..3e03d595d27 100644 --- a/tensorflow/core/platform/cloud/base64_test.cc +++ b/tensorflow/core/lib/strings/base64_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/platform/cloud/base64.h" +#include "tensorflow/core/lib/strings/base64.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index d940b0f563d..3b9f96e4964 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -247,4 +247,38 @@ shape: a length-2 vector of int64 representing the shape of the sparse of tokens in a single input entry. )doc"); +REGISTER_OP("EncodeBase64") + .Input("input: string") + .Output("output: string") + .Attr("pad: bool = false") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Encode strings into web-safe base64 format. + +Refer to the following article for more information on base64 format: +en.wikipedia.org/wiki/Base64. Base64 strings may have padding with '=' at the +end so that the encoded has length multiple of 4. See Padding section of the +link above. + +Web-safe means that the encoder uses - and _ instead of + and /. + +input: Strings to be encoded. +output: Input strings encoded in base64. +pad: Bool whether padding is applied at the ends. +)doc"); + +REGISTER_OP("DecodeBase64") + .Input("input: string") + .Output("output: string") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Decode web-safe base64-encoded strings. + +Input may or may not have padding at the end. See EncodeBase64 for padding. +Web-safe means that input must use - and _ instead of + and /. + +input: Base64 strings to decode. +output: Decoded strings. 
+)doc"); + } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD index f5af42103a5..7ee9f4e400c 100644 --- a/tensorflow/core/platform/cloud/BUILD +++ b/tensorflow/core/platform/cloud/BUILD @@ -83,10 +83,10 @@ cc_library( "google_auth_provider.h", ], deps = [ - ":base64", ":http_request", ":oauth_client", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "@jsoncpp_git//:jsoncpp", ], ) @@ -100,27 +100,14 @@ cc_library( "oauth_client.h", ], deps = [ - ":base64", ":http_request", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "@boringssl_git//:crypto", "@jsoncpp_git//:jsoncpp", ], ) -cc_library( - name = "base64", - srcs = [ - "base64.cc", - ], - hdrs = [ - "base64.h", - ], - deps = [ - "//tensorflow/core:lib", - ], -) - cc_library( name = "retrying_file_system", srcs = [ @@ -165,7 +152,6 @@ tf_cc_test( "testdata/service_account_public_key.txt", ], deps = [ - ":base64", ":http_request_fake", ":oauth_client", "//tensorflow/core:lib", @@ -176,18 +162,6 @@ tf_cc_test( ], ) -tf_cc_test( - name = "base64_test", - size = "small", - deps = [ - ":base64", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - tf_cc_test( name = "google_auth_provider_test", size = "small", @@ -196,7 +170,6 @@ tf_cc_test( "testdata/service_account_credentials.json", ], deps = [ - ":base64", ":google_auth_provider", ":http_request_fake", ":oauth_client", diff --git a/tensorflow/core/platform/cloud/base64.cc b/tensorflow/core/platform/cloud/base64.cc deleted file mode 100644 index 68dbf475d4d..00000000000 --- a/tensorflow/core/platform/cloud/base64.cc +++ /dev/null @@ -1,221 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/platform/cloud/base64.h" -#include -#include "tensorflow/core/lib/core/errors.h" - -namespace tensorflow { - -namespace { - -constexpr signed char kBase64Bytes[] = { - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, 0x3E, -1, -1, -1, 0x3F, - 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, -1, -1, - -1, 0x7F, -1, -1, -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, - 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, - 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, -1, -1, -1, -1, -1, - -1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24, - 0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30, - 0x31, 0x32, 0x33, -1, -1, -1, -1, -1}; - -constexpr char kBase64UrlSafeChars[] = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; - -constexpr char kPadChar = '='; -constexpr char kPadByte = 0x7F; -constexpr int kMultilineLineLen = 76; -constexpr int kMultilineNumBlocks = kMultilineLineLen / 4; - -Status Base64Encode(StringPiece source, bool multiline, bool with_padding, - string *encoded) { - if (!encoded) { - return errors::FailedPrecondition("'encoded' cannot be nullptr."); - } - size_t data_size = source.size(); - const char *data = source.data(); - const char *base64_chars = kBase64UrlSafeChars; - const size_t result_projected_size = - 4 * ((data_size + 3) / 3) 
+ - 2 * (multiline ? (data_size / (3 * kMultilineNumBlocks)) : 0) + 1; - size_t num_blocks = 0; - size_t i = 0; - std::unique_ptr result(new char[result_projected_size]); - char *current = result.get(); - - /* Encode each block. */ - while (data_size >= 3) { - *current++ = base64_chars[(data[i] >> 2) & 0x3F]; - *current++ = - base64_chars[((data[i] & 0x03) << 4) | ((data[i + 1] >> 4) & 0x0F)]; - *current++ = - base64_chars[((data[i + 1] & 0x0F) << 2) | ((data[i + 2] >> 6) & 0x03)]; - *current++ = base64_chars[data[i + 2] & 0x3F]; - - data_size -= 3; - i += 3; - if (multiline && (++num_blocks == kMultilineNumBlocks)) { - *current++ = '\r'; - *current++ = '\n'; - num_blocks = 0; - } - } - - /* Take care of the tail. */ - if (data_size == 2) { - *current++ = base64_chars[(data[i] >> 2) & 0x3F]; - *current++ = - base64_chars[((data[i] & 0x03) << 4) | ((data[i + 1] >> 4) & 0x0F)]; - *current++ = base64_chars[(data[i + 1] & 0x0F) << 2]; - if (with_padding) { - *current++ = kPadChar; - } - } else if (data_size == 1) { - *current++ = base64_chars[(data[i] >> 2) & 0x3F]; - *current++ = base64_chars[(data[i] & 0x03) << 4]; - if (with_padding) { - *current++ = kPadChar; - *current++ = kPadChar; - } - } - - if (current < result.get() || - current >= result.get() + result_projected_size) { - return errors::Internal("Unexpected encoding bug."); - } - *current++ = '\0'; - *encoded = result.get(); - return Status::OK(); -} - -void DecodeOneChar(const unsigned char *codes, unsigned char *result, - size_t *result_offset) { - const uint32_t packed = ((uint32_t)codes[0] << 2) | ((uint32_t)codes[1] >> 4); - result[(*result_offset)++] = (unsigned char)packed; -} - -void DecodeTwoChars(const unsigned char *codes, unsigned char *result, - size_t *result_offset) { - const uint32_t packed = ((uint32_t)codes[0] << 10) | - ((uint32_t)codes[1] << 4) | ((uint32_t)codes[2] >> 2); - result[(*result_offset)++] = (unsigned char)(packed >> 8); - result[(*result_offset)++] = (unsigned char)(packed); 
-} - -Status DecodeGroup(const unsigned char *codes, size_t num_codes, - unsigned char *result, size_t *result_offset) { - if (num_codes > 4) { - return errors::FailedPrecondition("Expected 4 or fewer codes."); - } - - /* Short end groups that may not have padding. */ - if (num_codes == 1) { - return errors::FailedPrecondition( - "Invalid group. Must be at least 2 bytes."); - } - if (num_codes == 2) { - DecodeOneChar(codes, result, result_offset); - return Status::OK(); - } - if (num_codes == 3) { - DecodeTwoChars(codes, result, result_offset); - return Status::OK(); - } - - /* Regular 4 byte groups with padding or not. */ - if (num_codes != 4) { - return errors::FailedPrecondition("Expected exactly 4 codes."); - } - if (codes[0] == kPadByte || codes[1] == kPadByte) { - return errors::FailedPrecondition("Invalid padding detected."); - } - if (codes[2] == kPadByte) { - if (codes[3] == kPadByte) { - DecodeOneChar(codes, result, result_offset); - } else { - return errors::FailedPrecondition("Invalid padding detected."); - } - } else if (codes[3] == kPadByte) { - DecodeTwoChars(codes, result, result_offset); - } else { - /* No padding. 
*/ - const uint32_t packed = ((uint32_t)codes[0] << 18) | - ((uint32_t)codes[1] << 12) | - ((uint32_t)codes[2] << 6) | codes[3]; - result[(*result_offset)++] = (unsigned char)(packed >> 16); - result[(*result_offset)++] = (unsigned char)(packed >> 8); - result[(*result_offset)++] = (unsigned char)(packed); - } - return Status::OK(); -} - -} // namespace - -Status Base64Encode(StringPiece source, string *encoded) { - return Base64Encode(source, false, false, encoded); -} - -Status Base64Decode(StringPiece data, string *decoded) { - if (!decoded) { - return errors::FailedPrecondition("'decoded' cannot be nullptr."); - } - std::unique_ptr result(new unsigned char[data.size()]); - unsigned char *current = result.get(); - size_t result_size = 0; - unsigned char codes[4]; - size_t num_codes = 0; - - const char *b64 = data.data(); - size_t b64_len = data.size(); - while (b64_len--) { - unsigned char c = (unsigned char)(*b64++); - signed char code; - if (c >= sizeof(kBase64Bytes)) continue; - if (c == '+' || c == '/') { - return errors::FailedPrecondition( - strings::StrCat("Invalid character for url safe base64 ", c)); - } - if (c == '-') { - c = '+'; - } else if (c == '_') { - c = '/'; - } - code = kBase64Bytes[c]; - if (code == -1) { - if (c != '\r' && c != '\n') { - return errors::FailedPrecondition( - strings::StrCat("Invalid character ", c)); - } - } else { - codes[num_codes++] = (unsigned char)code; - if (num_codes == 4) { - TF_RETURN_IF_ERROR( - DecodeGroup(codes, num_codes, current, &result_size)); - num_codes = 0; - } - } - } - - if (num_codes != 0) { - TF_RETURN_IF_ERROR(DecodeGroup(codes, num_codes, current, &result_size)); - } - *decoded = string(reinterpret_cast(result.get()), result_size); - return Status::OK(); -} - -} // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/google_auth_provider.cc b/tensorflow/core/platform/cloud/google_auth_provider.cc index 6d4ebb1d61c..4acdcd2f4a2 100644 --- 
a/tensorflow/core/platform/cloud/google_auth_provider.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider.cc @@ -21,7 +21,7 @@ limitations under the License. #include "include/json/json.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/io/path.h" -#include "tensorflow/core/platform/cloud/base64.h" +#include "tensorflow/core/lib/strings/base64.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index 36d1b0549e5..914509a6bda 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -22,7 +22,7 @@ limitations under the License. #include #include #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/cloud/base64.h" +#include "tensorflow/core/lib/strings/base64.h" #include "tensorflow/core/platform/cloud/http_request.h" #include "tensorflow/core/platform/env.h" diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index 4b23b6c7bd7..236259dbc16 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -20,8 +20,8 @@ limitations under the License. 
#include #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/base64.h" #include "tensorflow/core/lib/strings/scanner.h" -#include "tensorflow/core/platform/cloud/base64.h" #include "tensorflow/core/platform/cloud/http_request_fake.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index dc52ba44560..70bc65e0e44 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -22,6 +22,7 @@ py_tests( "as_string_op_test.py", "attention_ops_test.py", "barrier_ops_test.py", + "base64_ops_test.py", "bcast_ops_test.py", "benchmark_test.py", "candidate_sampler_ops_test.py", diff --git a/tensorflow/python/kernel_tests/base64_ops_test.py b/tensorflow/python/kernel_tests/base64_ops_test.py new file mode 100644 index 00000000000..be96f454979 --- /dev/null +++ b/tensorflow/python/kernel_tests/base64_ops_test.py @@ -0,0 +1,155 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for EncodeBase64 and DecodeBase64.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import base64 + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import string_ops +from tensorflow.python.platform import test + + +class Base64OpsTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._msg = array_ops.placeholder(dtype=dtypes.string) + self._encoded_f = string_ops.encode_base64(self._msg, pad=False) + self._decoded_f = string_ops.decode_base64(self._encoded_f) + self._encoded_t = string_ops.encode_base64(self._msg, pad=True) + self._decoded_t = string_ops.decode_base64(self._encoded_t) + + def _RemovePad(self, msg, base64_msg): + if len(msg) % 3 == 1: + return base64_msg[:-2] + if len(msg) % 3 == 2: + return base64_msg[:-1] + return base64_msg + + def _RunTest(self, msg, pad): + with self.test_session() as sess: + if pad: + encoded, decoded = sess.run([self._encoded_t, self._decoded_t], + feed_dict={self._msg: msg}) + else: + encoded, decoded = sess.run([self._encoded_f, self._decoded_f], + feed_dict={self._msg: msg}) + + if not isinstance(msg, (list, tuple)): + msg = [msg] + encoded = [encoded] + decoded = [decoded] + + base64_msg = [base64.urlsafe_b64encode(m) for m in msg] + if not pad: + base64_msg = [self._RemovePad(m, b) for m, b in zip(msg, base64_msg)] + + for i in range(len(msg)): + self.assertEqual(base64_msg[i], encoded[i]) + self.assertEqual(msg[i], decoded[i]) + + def testWithPythonBase64(self): + for pad in (False, True): + self._RunTest(b"", pad=pad) + + for _ in range(100): + length = np.random.randint(1024 * 1024) + msg = np.random.bytes(length) + self._RunTest(msg, pad=pad) + + def 
testShape(self): + for pad in (False, True): + for _ in range(10): + msg = [np.random.bytes(np.random.randint(20)) + for _ in range(np.random.randint(10))] + self._RunTest(msg, pad=pad) + + # Zero-element, non-trivial shapes. + for _ in range(10): + k = np.random.randint(10) + msg = np.empty((0, k), dtype=bytes) + encoded = string_ops.encode_base64(msg, pad=pad) + decoded = string_ops.decode_base64(encoded) + + with self.test_session() as sess: + encoded_value, decoded_value = sess.run([encoded, decoded]) + + self.assertEqual(encoded_value.shape, msg.shape) + self.assertEqual(decoded_value.shape, msg.shape) + + def testInvalidInput(self): + def try_decode(enc): + self._decoded_f.eval(feed_dict={self._encoded_f: enc}) + + with self.test_session(): + # Invalid length. + msg = np.random.bytes(99) + enc = base64.urlsafe_b64encode(msg) + with self.assertRaisesRegexp(errors.InvalidArgumentError, "1 modulo 4"): + try_decode(enc + b"a") + + # Invalid char used in encoding. + msg = np.random.bytes(34) + enc = base64.urlsafe_b64encode(msg) + for i in range(len(msg)): + with self.assertRaises(errors.InvalidArgumentError): + try_decode(enc[:i] + b"?" + enc[(i + 1):]) + with self.assertRaises(errors.InvalidArgumentError): + try_decode(enc[:i] + b"\x80" + enc[(i + 1):]) # outside ascii range. + with self.assertRaises(errors.InvalidArgumentError): + try_decode(enc[:i] + b"+" + enc[(i + 1):]) # not url-safe. + with self.assertRaises(errors.InvalidArgumentError): + try_decode(enc[:i] + b"/" + enc[(i + 1):]) # not url-safe. + + # Partial padding. + msg = np.random.bytes(34) + enc = base64.urlsafe_b64encode(msg) + with self.assertRaises(errors.InvalidArgumentError): + # enc contains == at the end. Partial padding is not allowed. + try_decode(enc[:-1]) + + # Unnecessary padding. 
+ msg = np.random.bytes(33) + enc = base64.urlsafe_b64encode(msg) + with self.assertRaises(errors.InvalidArgumentError): + try_decode(enc + b"==") + with self.assertRaises(errors.InvalidArgumentError): + try_decode(enc + b"===") + with self.assertRaises(errors.InvalidArgumentError): + try_decode(enc + b"====") + + # Padding in the middle. (Previous implementation was ok with this as long + # as padding char location was 2 or 3 (mod 4).) + msg = np.random.bytes(33) + enc = base64.urlsafe_b64encode(msg) + for i in range(len(msg) - 1): + with self.assertRaises(errors.InvalidArgumentError): + try_decode(enc[:i] + b"=" + enc[(i + 1):]) + for i in range(len(msg) - 2): + with self.assertRaises(errors.InvalidArgumentError): + try_decode(enc[:i] + b"==" + enc[(i + 2):]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py index e3450ee0b76..7ffa521ac68 100644 --- a/tensorflow/python/ops/string_ops.py +++ b/tensorflow/python/ops/string_ops.py @@ -37,6 +37,8 @@ string tensor. ## Conversion @@as_string +@@encode_base64 +@@decode_base64 """ from __future__ import absolute_import @@ -118,11 +120,15 @@ ops.NoGradient("ReduceJoin") ops.NoGradient("StringJoin") ops.NoGradient("StringSplit") ops.NoGradient("AsString") +ops.NoGradient("EncodeBase64") +ops.NoGradient("DecodeBase64") ops.RegisterShape("StringToHashBucket")(common_shapes.unchanged_shape) ops.RegisterShape("StringToHashBucketFast")(common_shapes.unchanged_shape) ops.RegisterShape("StringToHashBucketStrong")(common_shapes.unchanged_shape) ops.RegisterShape("AsString")(common_shapes.unchanged_shape) +ops.RegisterShape("EncodeBase64")(common_shapes.unchanged_shape) +ops.RegisterShape("DecodeBase64")(common_shapes.unchanged_shape) @ops.RegisterShape("ReduceJoin")