Google authentication for GCS file system.
Implements an authentication mechanism based on Application Default Credentials: https://developers.google.com/identity/protocols/application-default-credentials https://developers.google.com/identity/protocols/OAuth2ServiceAccount Change: 122741738
This commit is contained in:
parent
7584aa60a4
commit
bb465cdc0e
410
boringssl.BUILD
Normal file
410
boringssl.BUILD
Normal file
@ -0,0 +1,410 @@
|
||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||
|
||||
licenses(["restricted"]) # OpenSSL license, partly BSD-like
|
||||
|
||||
# See https://boringssl.googlesource.com/boringssl/+/master/INCORPORATING.md
|
||||
# on how to re-generate the list of source files.
|
||||
|
||||
crypto_headers = [
|
||||
"include/openssl/aead.h",
|
||||
"include/openssl/aes.h",
|
||||
"include/openssl/arm_arch.h",
|
||||
"include/openssl/asn1.h",
|
||||
"include/openssl/asn1_mac.h",
|
||||
"include/openssl/asn1t.h",
|
||||
"include/openssl/base.h",
|
||||
"include/openssl/base64.h",
|
||||
"include/openssl/bio.h",
|
||||
"include/openssl/blowfish.h",
|
||||
"include/openssl/bn.h",
|
||||
"include/openssl/buf.h",
|
||||
"include/openssl/buffer.h",
|
||||
"include/openssl/bytestring.h",
|
||||
"include/openssl/cast.h",
|
||||
"include/openssl/chacha.h",
|
||||
"include/openssl/cipher.h",
|
||||
"include/openssl/cmac.h",
|
||||
"include/openssl/conf.h",
|
||||
"include/openssl/cpu.h",
|
||||
"include/openssl/crypto.h",
|
||||
"include/openssl/curve25519.h",
|
||||
"include/openssl/des.h",
|
||||
"include/openssl/dh.h",
|
||||
"include/openssl/digest.h",
|
||||
"include/openssl/dsa.h",
|
||||
"include/openssl/ec.h",
|
||||
"include/openssl/ec_key.h",
|
||||
"include/openssl/ecdh.h",
|
||||
"include/openssl/ecdsa.h",
|
||||
"include/openssl/engine.h",
|
||||
"include/openssl/err.h",
|
||||
"include/openssl/evp.h",
|
||||
"include/openssl/ex_data.h",
|
||||
"include/openssl/hkdf.h",
|
||||
"include/openssl/hmac.h",
|
||||
"include/openssl/lhash.h",
|
||||
"include/openssl/lhash_macros.h",
|
||||
"include/openssl/md4.h",
|
||||
"include/openssl/md5.h",
|
||||
"include/openssl/mem.h",
|
||||
"include/openssl/newhope.h",
|
||||
"include/openssl/nid.h",
|
||||
"include/openssl/obj.h",
|
||||
"include/openssl/obj_mac.h",
|
||||
"include/openssl/objects.h",
|
||||
"include/openssl/opensslconf.h",
|
||||
"include/openssl/opensslv.h",
|
||||
"include/openssl/ossl_typ.h",
|
||||
"include/openssl/pem.h",
|
||||
"include/openssl/pkcs12.h",
|
||||
"include/openssl/pkcs7.h",
|
||||
"include/openssl/pkcs8.h",
|
||||
"include/openssl/poly1305.h",
|
||||
"include/openssl/pqueue.h",
|
||||
"include/openssl/rand.h",
|
||||
"include/openssl/rc4.h",
|
||||
"include/openssl/ripemd.h",
|
||||
"include/openssl/rsa.h",
|
||||
"include/openssl/safestack.h",
|
||||
"include/openssl/sha.h",
|
||||
"include/openssl/srtp.h",
|
||||
"include/openssl/stack.h",
|
||||
"include/openssl/stack_macros.h",
|
||||
"include/openssl/thread.h",
|
||||
"include/openssl/time_support.h",
|
||||
"include/openssl/type_check.h",
|
||||
"include/openssl/x509.h",
|
||||
"include/openssl/x509_vfy.h",
|
||||
"include/openssl/x509v3.h",
|
||||
]
|
||||
|
||||
crypto_internal_headers = [
|
||||
"crypto/aes/internal.h",
|
||||
"crypto/asn1/asn1_locl.h",
|
||||
"crypto/bio/internal.h",
|
||||
"crypto/bn/internal.h",
|
||||
"crypto/bn/rsaz_exp.h",
|
||||
"crypto/bytestring/internal.h",
|
||||
"crypto/cipher/internal.h",
|
||||
"crypto/conf/conf_def.h",
|
||||
"crypto/conf/internal.h",
|
||||
"crypto/curve25519/internal.h",
|
||||
"crypto/des/internal.h",
|
||||
"crypto/dh/internal.h",
|
||||
"crypto/digest/internal.h",
|
||||
"crypto/digest/md32_common.h",
|
||||
"crypto/ec/internal.h",
|
||||
"crypto/ec/p256-x86_64-table.h",
|
||||
"crypto/evp/internal.h",
|
||||
"crypto/internal.h",
|
||||
"crypto/modes/internal.h",
|
||||
"crypto/newhope/internal.h",
|
||||
"crypto/obj/obj_dat.h",
|
||||
"crypto/obj/obj_xref.h",
|
||||
"crypto/pkcs8/internal.h",
|
||||
"crypto/poly1305/internal.h",
|
||||
"crypto/rand/internal.h",
|
||||
"crypto/rsa/internal.h",
|
||||
"crypto/test/scoped_types.h",
|
||||
"crypto/test/test_util.h",
|
||||
"crypto/x509/charmap.h",
|
||||
"crypto/x509/internal.h",
|
||||
"crypto/x509/vpm_int.h",
|
||||
"crypto/x509v3/ext_dat.h",
|
||||
"crypto/x509v3/pcy_int.h",
|
||||
]
|
||||
|
||||
crypto_sources = [
|
||||
":err_data_c",
|
||||
"crypto/aes/aes.c",
|
||||
"crypto/aes/mode_wrappers.c",
|
||||
"crypto/asn1/a_bitstr.c",
|
||||
"crypto/asn1/a_bool.c",
|
||||
"crypto/asn1/a_bytes.c",
|
||||
"crypto/asn1/a_d2i_fp.c",
|
||||
"crypto/asn1/a_dup.c",
|
||||
"crypto/asn1/a_enum.c",
|
||||
"crypto/asn1/a_gentm.c",
|
||||
"crypto/asn1/a_i2d_fp.c",
|
||||
"crypto/asn1/a_int.c",
|
||||
"crypto/asn1/a_mbstr.c",
|
||||
"crypto/asn1/a_object.c",
|
||||
"crypto/asn1/a_octet.c",
|
||||
"crypto/asn1/a_print.c",
|
||||
"crypto/asn1/a_strnid.c",
|
||||
"crypto/asn1/a_time.c",
|
||||
"crypto/asn1/a_type.c",
|
||||
"crypto/asn1/a_utctm.c",
|
||||
"crypto/asn1/a_utf8.c",
|
||||
"crypto/asn1/asn1_lib.c",
|
||||
"crypto/asn1/asn1_par.c",
|
||||
"crypto/asn1/asn_pack.c",
|
||||
"crypto/asn1/bio_asn1.c",
|
||||
"crypto/asn1/bio_ndef.c",
|
||||
"crypto/asn1/f_enum.c",
|
||||
"crypto/asn1/f_int.c",
|
||||
"crypto/asn1/f_string.c",
|
||||
"crypto/asn1/t_bitst.c",
|
||||
"crypto/asn1/tasn_dec.c",
|
||||
"crypto/asn1/tasn_enc.c",
|
||||
"crypto/asn1/tasn_fre.c",
|
||||
"crypto/asn1/tasn_new.c",
|
||||
"crypto/asn1/tasn_prn.c",
|
||||
"crypto/asn1/tasn_typ.c",
|
||||
"crypto/asn1/tasn_utl.c",
|
||||
"crypto/asn1/x_bignum.c",
|
||||
"crypto/asn1/x_long.c",
|
||||
"crypto/base64/base64.c",
|
||||
"crypto/bio/bio.c",
|
||||
"crypto/bio/bio_mem.c",
|
||||
"crypto/bio/buffer.c",
|
||||
"crypto/bio/connect.c",
|
||||
"crypto/bio/fd.c",
|
||||
"crypto/bio/file.c",
|
||||
"crypto/bio/hexdump.c",
|
||||
"crypto/bio/pair.c",
|
||||
"crypto/bio/printf.c",
|
||||
"crypto/bio/socket.c",
|
||||
"crypto/bio/socket_helper.c",
|
||||
"crypto/bn/add.c",
|
||||
"crypto/bn/asm/x86_64-gcc.c",
|
||||
"crypto/bn/bn.c",
|
||||
"crypto/bn/bn_asn1.c",
|
||||
"crypto/bn/cmp.c",
|
||||
"crypto/bn/convert.c",
|
||||
"crypto/bn/ctx.c",
|
||||
"crypto/bn/div.c",
|
||||
"crypto/bn/exponentiation.c",
|
||||
"crypto/bn/gcd.c",
|
||||
"crypto/bn/generic.c",
|
||||
"crypto/bn/kronecker.c",
|
||||
"crypto/bn/montgomery.c",
|
||||
"crypto/bn/mul.c",
|
||||
"crypto/bn/prime.c",
|
||||
"crypto/bn/random.c",
|
||||
"crypto/bn/rsaz_exp.c",
|
||||
"crypto/bn/shift.c",
|
||||
"crypto/bn/sqrt.c",
|
||||
"crypto/buf/buf.c",
|
||||
"crypto/bytestring/asn1_compat.c",
|
||||
"crypto/bytestring/ber.c",
|
||||
"crypto/bytestring/cbb.c",
|
||||
"crypto/bytestring/cbs.c",
|
||||
"crypto/chacha/chacha.c",
|
||||
"crypto/cipher/aead.c",
|
||||
"crypto/cipher/cipher.c",
|
||||
"crypto/cipher/derive_key.c",
|
||||
"crypto/cipher/e_aes.c",
|
||||
"crypto/cipher/e_chacha20poly1305.c",
|
||||
"crypto/cipher/e_des.c",
|
||||
"crypto/cipher/e_null.c",
|
||||
"crypto/cipher/e_rc2.c",
|
||||
"crypto/cipher/e_rc4.c",
|
||||
"crypto/cipher/e_ssl3.c",
|
||||
"crypto/cipher/e_tls.c",
|
||||
"crypto/cipher/tls_cbc.c",
|
||||
"crypto/cmac/cmac.c",
|
||||
"crypto/conf/conf.c",
|
||||
"crypto/cpu-aarch64-linux.c",
|
||||
"crypto/cpu-arm-linux.c",
|
||||
"crypto/cpu-arm.c",
|
||||
"crypto/cpu-intel.c",
|
||||
"crypto/crypto.c",
|
||||
"crypto/curve25519/curve25519.c",
|
||||
"crypto/curve25519/spake25519.c",
|
||||
"crypto/curve25519/x25519-x86_64.c",
|
||||
"crypto/des/des.c",
|
||||
"crypto/dh/check.c",
|
||||
"crypto/dh/dh.c",
|
||||
"crypto/dh/dh_asn1.c",
|
||||
"crypto/dh/params.c",
|
||||
"crypto/digest/digest.c",
|
||||
"crypto/digest/digests.c",
|
||||
"crypto/dsa/dsa.c",
|
||||
"crypto/dsa/dsa_asn1.c",
|
||||
"crypto/ec/ec.c",
|
||||
"crypto/ec/ec_asn1.c",
|
||||
"crypto/ec/ec_key.c",
|
||||
"crypto/ec/ec_montgomery.c",
|
||||
"crypto/ec/oct.c",
|
||||
"crypto/ec/p224-64.c",
|
||||
"crypto/ec/p256-64.c",
|
||||
"crypto/ec/p256-x86_64.c",
|
||||
"crypto/ec/simple.c",
|
||||
"crypto/ec/util-64.c",
|
||||
"crypto/ec/wnaf.c",
|
||||
"crypto/ecdh/ecdh.c",
|
||||
"crypto/ecdsa/ecdsa.c",
|
||||
"crypto/ecdsa/ecdsa_asn1.c",
|
||||
"crypto/engine/engine.c",
|
||||
"crypto/err/err.c",
|
||||
"crypto/evp/digestsign.c",
|
||||
"crypto/evp/evp.c",
|
||||
"crypto/evp/evp_asn1.c",
|
||||
"crypto/evp/evp_ctx.c",
|
||||
"crypto/evp/p_dsa_asn1.c",
|
||||
"crypto/evp/p_ec.c",
|
||||
"crypto/evp/p_ec_asn1.c",
|
||||
"crypto/evp/p_rsa.c",
|
||||
"crypto/evp/p_rsa_asn1.c",
|
||||
"crypto/evp/pbkdf.c",
|
||||
"crypto/evp/print.c",
|
||||
"crypto/evp/sign.c",
|
||||
"crypto/ex_data.c",
|
||||
"crypto/hkdf/hkdf.c",
|
||||
"crypto/hmac/hmac.c",
|
||||
"crypto/lhash/lhash.c",
|
||||
"crypto/md4/md4.c",
|
||||
"crypto/md5/md5.c",
|
||||
"crypto/mem.c",
|
||||
"crypto/modes/cbc.c",
|
||||
"crypto/modes/cfb.c",
|
||||
"crypto/modes/ctr.c",
|
||||
"crypto/modes/gcm.c",
|
||||
"crypto/modes/ofb.c",
|
||||
"crypto/newhope/error_correction.c",
|
||||
"crypto/newhope/newhope.c",
|
||||
"crypto/newhope/ntt.c",
|
||||
"crypto/newhope/poly.c",
|
||||
"crypto/newhope/precomp.c",
|
||||
"crypto/newhope/reduce.c",
|
||||
"crypto/obj/obj.c",
|
||||
"crypto/obj/obj_xref.c",
|
||||
"crypto/pem/pem_all.c",
|
||||
"crypto/pem/pem_info.c",
|
||||
"crypto/pem/pem_lib.c",
|
||||
"crypto/pem/pem_oth.c",
|
||||
"crypto/pem/pem_pk8.c",
|
||||
"crypto/pem/pem_pkey.c",
|
||||
"crypto/pem/pem_x509.c",
|
||||
"crypto/pem/pem_xaux.c",
|
||||
"crypto/pkcs8/p5_pbe.c",
|
||||
"crypto/pkcs8/p5_pbev2.c",
|
||||
"crypto/pkcs8/p8_pkey.c",
|
||||
"crypto/pkcs8/pkcs8.c",
|
||||
"crypto/poly1305/poly1305.c",
|
||||
"crypto/poly1305/poly1305_arm.c",
|
||||
"crypto/poly1305/poly1305_vec.c",
|
||||
"crypto/rand/deterministic.c",
|
||||
"crypto/rand/rand.c",
|
||||
"crypto/rand/urandom.c",
|
||||
"crypto/rand/windows.c",
|
||||
"crypto/rc4/rc4.c",
|
||||
"crypto/refcount_c11.c",
|
||||
"crypto/refcount_lock.c",
|
||||
"crypto/rsa/blinding.c",
|
||||
"crypto/rsa/padding.c",
|
||||
"crypto/rsa/rsa.c",
|
||||
"crypto/rsa/rsa_asn1.c",
|
||||
"crypto/rsa/rsa_impl.c",
|
||||
"crypto/sha/sha1.c",
|
||||
"crypto/sha/sha256.c",
|
||||
"crypto/sha/sha512.c",
|
||||
"crypto/stack/stack.c",
|
||||
"crypto/thread.c",
|
||||
"crypto/thread_none.c",
|
||||
"crypto/thread_pthread.c",
|
||||
"crypto/thread_win.c",
|
||||
"crypto/time_support.c",
|
||||
"crypto/x509/a_digest.c",
|
||||
"crypto/x509/a_sign.c",
|
||||
"crypto/x509/a_strex.c",
|
||||
"crypto/x509/a_verify.c",
|
||||
"crypto/x509/algorithm.c",
|
||||
"crypto/x509/asn1_gen.c",
|
||||
"crypto/x509/by_dir.c",
|
||||
"crypto/x509/by_file.c",
|
||||
"crypto/x509/i2d_pr.c",
|
||||
"crypto/x509/pkcs7.c",
|
||||
"crypto/x509/rsa_pss.c",
|
||||
"crypto/x509/t_crl.c",
|
||||
"crypto/x509/t_req.c",
|
||||
"crypto/x509/t_x509.c",
|
||||
"crypto/x509/t_x509a.c",
|
||||
"crypto/x509/x509.c",
|
||||
"crypto/x509/x509_att.c",
|
||||
"crypto/x509/x509_cmp.c",
|
||||
"crypto/x509/x509_d2.c",
|
||||
"crypto/x509/x509_def.c",
|
||||
"crypto/x509/x509_ext.c",
|
||||
"crypto/x509/x509_lu.c",
|
||||
"crypto/x509/x509_obj.c",
|
||||
"crypto/x509/x509_r2x.c",
|
||||
"crypto/x509/x509_req.c",
|
||||
"crypto/x509/x509_set.c",
|
||||
"crypto/x509/x509_trs.c",
|
||||
"crypto/x509/x509_txt.c",
|
||||
"crypto/x509/x509_v3.c",
|
||||
"crypto/x509/x509_vfy.c",
|
||||
"crypto/x509/x509_vpm.c",
|
||||
"crypto/x509/x509cset.c",
|
||||
"crypto/x509/x509name.c",
|
||||
"crypto/x509/x509rset.c",
|
||||
"crypto/x509/x509spki.c",
|
||||
"crypto/x509/x509type.c",
|
||||
"crypto/x509/x_algor.c",
|
||||
"crypto/x509/x_all.c",
|
||||
"crypto/x509/x_attrib.c",
|
||||
"crypto/x509/x_crl.c",
|
||||
"crypto/x509/x_exten.c",
|
||||
"crypto/x509/x_info.c",
|
||||
"crypto/x509/x_name.c",
|
||||
"crypto/x509/x_pkey.c",
|
||||
"crypto/x509/x_pubkey.c",
|
||||
"crypto/x509/x_req.c",
|
||||
"crypto/x509/x_sig.c",
|
||||
"crypto/x509/x_spki.c",
|
||||
"crypto/x509/x_val.c",
|
||||
"crypto/x509/x_x509.c",
|
||||
"crypto/x509/x_x509a.c",
|
||||
"crypto/x509v3/pcy_cache.c",
|
||||
"crypto/x509v3/pcy_data.c",
|
||||
"crypto/x509v3/pcy_lib.c",
|
||||
"crypto/x509v3/pcy_map.c",
|
||||
"crypto/x509v3/pcy_node.c",
|
||||
"crypto/x509v3/pcy_tree.c",
|
||||
"crypto/x509v3/v3_akey.c",
|
||||
"crypto/x509v3/v3_akeya.c",
|
||||
"crypto/x509v3/v3_alt.c",
|
||||
"crypto/x509v3/v3_bcons.c",
|
||||
"crypto/x509v3/v3_bitst.c",
|
||||
"crypto/x509v3/v3_conf.c",
|
||||
"crypto/x509v3/v3_cpols.c",
|
||||
"crypto/x509v3/v3_crld.c",
|
||||
"crypto/x509v3/v3_enum.c",
|
||||
"crypto/x509v3/v3_extku.c",
|
||||
"crypto/x509v3/v3_genn.c",
|
||||
"crypto/x509v3/v3_ia5.c",
|
||||
"crypto/x509v3/v3_info.c",
|
||||
"crypto/x509v3/v3_int.c",
|
||||
"crypto/x509v3/v3_lib.c",
|
||||
"crypto/x509v3/v3_ncons.c",
|
||||
"crypto/x509v3/v3_pci.c",
|
||||
"crypto/x509v3/v3_pcia.c",
|
||||
"crypto/x509v3/v3_pcons.c",
|
||||
"crypto/x509v3/v3_pku.c",
|
||||
"crypto/x509v3/v3_pmaps.c",
|
||||
"crypto/x509v3/v3_prn.c",
|
||||
"crypto/x509v3/v3_purp.c",
|
||||
"crypto/x509v3/v3_skey.c",
|
||||
"crypto/x509v3/v3_sxnet.c",
|
||||
"crypto/x509v3/v3_utl.c",
|
||||
]
|
||||
|
||||
# A trick to take the generated err_data.c from another package.
|
||||
genrule(
|
||||
name = "err_data_c",
|
||||
srcs = ["@//third_party/boringssl:err_data_c"],
|
||||
outs = ["err_data.c"],
|
||||
cmd = "cp $< $@",
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "crypto",
|
||||
srcs = crypto_internal_headers + crypto_sources,
|
||||
hdrs = crypto_headers,
|
||||
# To avoid linking platform-specific ASM files.
|
||||
defines = ["OPENSSL_NO_ASM"],
|
||||
includes = ["include"],
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
@ -36,6 +36,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@jsoncpp_git//:jsoncpp",
|
||||
":google_auth_provider",
|
||||
":http_request",
|
||||
"//tensorflow/core:framework_headers_lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
@ -57,11 +58,74 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "http_request_fake",
|
||||
testonly = 1,
|
||||
hdrs = [
|
||||
"http_request_fake.h",
|
||||
],
|
||||
deps = [
|
||||
":http_request",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "google_auth_provider",
|
||||
srcs = [
|
||||
"google_auth_provider.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"auth_provider.h",
|
||||
"google_auth_provider.h",
|
||||
],
|
||||
deps = [
|
||||
"@jsoncpp_git//:jsoncpp",
|
||||
":base64",
|
||||
":http_request",
|
||||
":oauth_client",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "oauth_client",
|
||||
srcs = [
|
||||
"oauth_client.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"oauth_client.h",
|
||||
],
|
||||
deps = [
|
||||
"@boringssl_git//:crypto",
|
||||
"@jsoncpp_git//:jsoncpp",
|
||||
":base64",
|
||||
":http_request",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "base64",
|
||||
srcs = [
|
||||
"base64.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"base64.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gcs_file_system_test",
|
||||
size = "small",
|
||||
deps = [
|
||||
":gcs_file_system",
|
||||
":http_request_fake",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
@ -77,3 +141,53 @@ tf_cc_test(
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "oauth_client_test",
|
||||
size = "small",
|
||||
data = [
|
||||
"testdata/service_account_credentials.json",
|
||||
"testdata/service_account_public_key.txt",
|
||||
],
|
||||
deps = [
|
||||
"@boringssl_git//:crypto",
|
||||
":base64",
|
||||
":http_request_fake",
|
||||
":oauth_client",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "base64_test",
|
||||
size = "small",
|
||||
deps = [
|
||||
":base64",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "google_auth_provider_test",
|
||||
size = "small",
|
||||
data = [
|
||||
"testdata/application_default_credentials.json",
|
||||
"testdata/service_account_credentials.json",
|
||||
],
|
||||
deps = [
|
||||
":base64",
|
||||
":google_auth_provider",
|
||||
":http_request_fake",
|
||||
":oauth_client",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
52
tensorflow/core/platform/cloud/auth_provider.h
Normal file
52
tensorflow/core/platform/cloud/auth_provider.h
Normal file
@ -0,0 +1,52 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_AUTH_PROVIDER_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_AUTH_PROVIDER_H_
|
||||
|
||||
#include <string>
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/// Interface for a provider of authentication bearer tokens.
|
||||
class AuthProvider {
|
||||
public:
|
||||
virtual ~AuthProvider() {}
|
||||
|
||||
/// Returns the short-term authentication bearer token.
|
||||
virtual Status GetToken(string* t) = 0;
|
||||
|
||||
static Status GetToken(AuthProvider* provider, string* token) {
|
||||
if (!provider) {
|
||||
return errors::Internal("Auth provider is required.");
|
||||
}
|
||||
return provider->GetToken(token);
|
||||
}
|
||||
};
|
||||
|
||||
/// No-op auth provider, which will only work for public objects.
|
||||
class EmptyAuthProvider : public AuthProvider {
|
||||
public:
|
||||
Status GetToken(string* token) override {
|
||||
*token = "";
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_AUTH_PROVIDER_H_
|
221
tensorflow/core/platform/cloud/base64.cc
Normal file
221
tensorflow/core/platform/cloud/base64.cc
Normal file
@ -0,0 +1,221 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include <memory>
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr signed char kBase64Bytes[] = {
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
|
||||
-1, -1, -1, -1, -1, -1, -1, 0x3E, -1, -1, -1, 0x3F,
|
||||
0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3A, 0x3B, 0x3C, 0x3D, -1, -1,
|
||||
-1, 0x7F, -1, -1, -1, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06,
|
||||
0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12,
|
||||
0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, -1, -1, -1, -1, -1,
|
||||
-1, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F, 0x20, 0x21, 0x22, 0x23, 0x24,
|
||||
0x25, 0x26, 0x27, 0x28, 0x29, 0x2A, 0x2B, 0x2C, 0x2D, 0x2E, 0x2F, 0x30,
|
||||
0x31, 0x32, 0x33, -1, -1, -1, -1, -1};
|
||||
|
||||
constexpr char kBase64UrlSafeChars[] =
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
|
||||
|
||||
constexpr char kPadChar = '=';
|
||||
constexpr char kPadByte = 0x7F;
|
||||
constexpr int kMultilineLineLen = 76;
|
||||
constexpr int kMultilineNumBlocks = kMultilineLineLen / 4;
|
||||
|
||||
Status Base64Encode(StringPiece source, bool multiline, bool with_padding,
|
||||
string *encoded) {
|
||||
if (!encoded) {
|
||||
return errors::FailedPrecondition("'encoded' cannot be nullptr.");
|
||||
}
|
||||
size_t data_size = source.size();
|
||||
const char *data = source.data();
|
||||
const char *base64_chars = kBase64UrlSafeChars;
|
||||
const size_t result_projected_size =
|
||||
4 * ((data_size + 3) / 3) +
|
||||
2 * (multiline ? (data_size / (3 * kMultilineNumBlocks)) : 0) + 1;
|
||||
size_t num_blocks = 0;
|
||||
size_t i = 0;
|
||||
std::unique_ptr<char[]> result(new char[result_projected_size]);
|
||||
char *current = result.get();
|
||||
|
||||
/* Encode each block. */
|
||||
while (data_size >= 3) {
|
||||
*current++ = base64_chars[(data[i] >> 2) & 0x3F];
|
||||
*current++ =
|
||||
base64_chars[((data[i] & 0x03) << 4) | ((data[i + 1] >> 4) & 0x0F)];
|
||||
*current++ =
|
||||
base64_chars[((data[i + 1] & 0x0F) << 2) | ((data[i + 2] >> 6) & 0x03)];
|
||||
*current++ = base64_chars[data[i + 2] & 0x3F];
|
||||
|
||||
data_size -= 3;
|
||||
i += 3;
|
||||
if (multiline && (++num_blocks == kMultilineNumBlocks)) {
|
||||
*current++ = '\r';
|
||||
*current++ = '\n';
|
||||
num_blocks = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/* Take care of the tail. */
|
||||
if (data_size == 2) {
|
||||
*current++ = base64_chars[(data[i] >> 2) & 0x3F];
|
||||
*current++ =
|
||||
base64_chars[((data[i] & 0x03) << 4) | ((data[i + 1] >> 4) & 0x0F)];
|
||||
*current++ = base64_chars[(data[i + 1] & 0x0F) << 2];
|
||||
if (with_padding) {
|
||||
*current++ = kPadChar;
|
||||
}
|
||||
} else if (data_size == 1) {
|
||||
*current++ = base64_chars[(data[i] >> 2) & 0x3F];
|
||||
*current++ = base64_chars[(data[i] & 0x03) << 4];
|
||||
if (with_padding) {
|
||||
*current++ = kPadChar;
|
||||
*current++ = kPadChar;
|
||||
}
|
||||
}
|
||||
|
||||
if (current < result.get() ||
|
||||
current >= result.get() + result_projected_size) {
|
||||
return errors::Internal("Unexpected encoding bug.");
|
||||
}
|
||||
*current++ = '\0';
|
||||
*encoded = result.get();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void DecodeOneChar(const unsigned char *codes, unsigned char *result,
|
||||
size_t *result_offset) {
|
||||
const uint32_t packed = ((uint32_t)codes[0] << 2) | ((uint32_t)codes[1] >> 4);
|
||||
result[(*result_offset)++] = (unsigned char)packed;
|
||||
}
|
||||
|
||||
void DecodeTwoChars(const unsigned char *codes, unsigned char *result,
|
||||
size_t *result_offset) {
|
||||
const uint32_t packed = ((uint32_t)codes[0] << 10) |
|
||||
((uint32_t)codes[1] << 4) | ((uint32_t)codes[2] >> 2);
|
||||
result[(*result_offset)++] = (unsigned char)(packed >> 8);
|
||||
result[(*result_offset)++] = (unsigned char)(packed);
|
||||
}
|
||||
|
||||
Status DecodeGroup(const unsigned char *codes, size_t num_codes,
|
||||
unsigned char *result, size_t *result_offset) {
|
||||
if (num_codes > 4) {
|
||||
return errors::FailedPrecondition("Expected 4 or fewer codes.");
|
||||
}
|
||||
|
||||
/* Short end groups that may not have padding. */
|
||||
if (num_codes == 1) {
|
||||
return errors::FailedPrecondition(
|
||||
"Invalid group. Must be at least 2 bytes.");
|
||||
}
|
||||
if (num_codes == 2) {
|
||||
DecodeOneChar(codes, result, result_offset);
|
||||
return Status::OK();
|
||||
}
|
||||
if (num_codes == 3) {
|
||||
DecodeTwoChars(codes, result, result_offset);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* Regular 4 byte groups with padding or not. */
|
||||
if (num_codes != 4) {
|
||||
return errors::FailedPrecondition("Expected exactly 4 codes.");
|
||||
}
|
||||
if (codes[0] == kPadByte || codes[1] == kPadByte) {
|
||||
return errors::FailedPrecondition("Invalid padding detected.");
|
||||
}
|
||||
if (codes[2] == kPadByte) {
|
||||
if (codes[3] == kPadByte) {
|
||||
DecodeOneChar(codes, result, result_offset);
|
||||
} else {
|
||||
return errors::FailedPrecondition("Invalid padding detected.");
|
||||
}
|
||||
} else if (codes[3] == kPadByte) {
|
||||
DecodeTwoChars(codes, result, result_offset);
|
||||
} else {
|
||||
/* No padding. */
|
||||
const uint32_t packed = ((uint32_t)codes[0] << 18) |
|
||||
((uint32_t)codes[1] << 12) |
|
||||
((uint32_t)codes[2] << 6) | codes[3];
|
||||
result[(*result_offset)++] = (unsigned char)(packed >> 16);
|
||||
result[(*result_offset)++] = (unsigned char)(packed >> 8);
|
||||
result[(*result_offset)++] = (unsigned char)(packed);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status Base64Encode(StringPiece source, string *encoded) {
|
||||
return Base64Encode(source, false, false, encoded);
|
||||
}
|
||||
|
||||
Status Base64Decode(StringPiece data, string *decoded) {
|
||||
if (!decoded) {
|
||||
return errors::FailedPrecondition("'decoded' cannot be nullptr.");
|
||||
}
|
||||
std::unique_ptr<unsigned char[]> result(new unsigned char[data.size()]);
|
||||
unsigned char *current = result.get();
|
||||
size_t result_size = 0;
|
||||
unsigned char codes[4];
|
||||
size_t num_codes = 0;
|
||||
|
||||
const char *b64 = data.data();
|
||||
size_t b64_len = data.size();
|
||||
while (b64_len--) {
|
||||
unsigned char c = (unsigned char)(*b64++);
|
||||
signed char code;
|
||||
if (c >= sizeof(kBase64Bytes)) continue;
|
||||
if (c == '+' || c == '/') {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Invalid character for url safe base64 ", c));
|
||||
}
|
||||
if (c == '-') {
|
||||
c = '+';
|
||||
} else if (c == '_') {
|
||||
c = '/';
|
||||
}
|
||||
code = kBase64Bytes[c];
|
||||
if (code == -1) {
|
||||
if (c != '\r' && c != '\n') {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Invalid character ", c));
|
||||
}
|
||||
} else {
|
||||
codes[num_codes++] = (unsigned char)code;
|
||||
if (num_codes == 4) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
DecodeGroup(codes, num_codes, current, &result_size));
|
||||
num_codes = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (num_codes != 0) {
|
||||
TF_RETURN_IF_ERROR(DecodeGroup(codes, num_codes, current, &result_size));
|
||||
}
|
||||
*decoded = string(reinterpret_cast<char *>(result.get()), result_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
36
tensorflow/core/platform/cloud/base64.h
Normal file
36
tensorflow/core/platform/cloud/base64.h
Normal file
@ -0,0 +1,36 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_B64_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_B64_H_
|
||||
|
||||
#include <string>
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/// \brief Converts data into base64 encoding.
|
||||
///
|
||||
/// See https://en.wikipedia.org/wiki/Base64
|
||||
Status Base64Encode(StringPiece data, string* encoded);
|
||||
|
||||
/// \brief Converts data from base64 encoding.
|
||||
///
|
||||
/// See https://en.wikipedia.org/wiki/Base64
|
||||
Status Base64Decode(StringPiece data, string* decoded);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_B64_H_
|
33
tensorflow/core/platform/cloud/base64_test.cc
Normal file
33
tensorflow/core/platform/cloud/base64_test.cc
Normal file
@ -0,0 +1,33 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
TEST(Base64, EncodeDecode) {
|
||||
const string original = "a simple test message!";
|
||||
string encoded;
|
||||
TF_EXPECT_OK(Base64Encode(original, &encoded));
|
||||
EXPECT_EQ("YSBzaW1wbGUgdGVzdCBtZXNzYWdlIQ", encoded);
|
||||
|
||||
string decoded;
|
||||
TF_EXPECT_OK(Base64Decode(encoded, &decoded));
|
||||
EXPECT_EQ(original, decoded);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/cloud/google_auth_provider.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
@ -56,22 +57,6 @@ Status GetTmpFilename(string* filename) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// No-op auth provider, which will only work for public objects.
|
||||
class EmptyAuthProvider : public AuthProvider {
|
||||
public:
|
||||
Status GetToken(string* token) const override {
|
||||
*token = "";
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
Status GetAuthToken(const AuthProvider* provider, string* token) {
|
||||
if (!provider) {
|
||||
return errors::Internal("Auth provider is required.");
|
||||
}
|
||||
return provider->GetToken(token);
|
||||
}
|
||||
|
||||
/// \brief Splits a GCS path to a bucket and an object.
|
||||
///
|
||||
/// For example, "gs://bucket-name/path/to/file.txt" gets split into
|
||||
@ -109,7 +94,7 @@ class GcsRandomAccessFile : public RandomAccessFile {
|
||||
Status Read(uint64 offset, size_t n, StringPiece* result,
|
||||
char* scratch) const override {
|
||||
string auth_token;
|
||||
TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_, &auth_token));
|
||||
TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_, &auth_token));
|
||||
|
||||
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
|
||||
TF_RETURN_IF_ERROR(request->Init());
|
||||
@ -198,7 +183,7 @@ class GcsWritableFile : public WritableFile {
|
||||
TF_RETURN_IF_ERROR(CheckWritable());
|
||||
outfile_.flush();
|
||||
string auth_token;
|
||||
TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_, &auth_token));
|
||||
TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_, &auth_token));
|
||||
|
||||
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
|
||||
TF_RETURN_IF_ERROR(request->Init());
|
||||
@ -242,7 +227,7 @@ class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
|
||||
} // namespace
|
||||
|
||||
GcsFileSystem::GcsFileSystem()
|
||||
: auth_provider_(new EmptyAuthProvider()),
|
||||
: auth_provider_(new GoogleAuthProvider()),
|
||||
http_request_factory_(new HttpRequest::Factory()) {}
|
||||
|
||||
GcsFileSystem::GcsFileSystem(
|
||||
@ -334,7 +319,7 @@ bool GcsFileSystem::FileExists(const string& fname) {
|
||||
}
|
||||
|
||||
string auth_token;
|
||||
if (!GetAuthToken(auth_provider_.get(), &auth_token).ok()) {
|
||||
if (!AuthProvider::GetToken(auth_provider_.get(), &auth_token).ok()) {
|
||||
LOG(ERROR) << "Could not get an auth token.";
|
||||
return false;
|
||||
}
|
||||
@ -363,7 +348,7 @@ Status GcsFileSystem::GetChildren(const string& dirname,
|
||||
TF_RETURN_IF_ERROR(ParseGcsPath(sanitized_dirname, &bucket, &object_prefix));
|
||||
|
||||
string auth_token;
|
||||
TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));
|
||||
TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));
|
||||
|
||||
std::unique_ptr<char[]> scratch(new char[kBufferSize]);
|
||||
StringPiece response_piece;
|
||||
@ -417,7 +402,7 @@ Status GcsFileSystem::DeleteFile(const string& fname) {
|
||||
TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object));
|
||||
|
||||
string auth_token;
|
||||
TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));
|
||||
TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));
|
||||
|
||||
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
|
||||
TF_RETURN_IF_ERROR(request->Init());
|
||||
@ -452,7 +437,7 @@ Status GcsFileSystem::GetFileSize(const string& fname, uint64* file_size) {
|
||||
TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object_prefix));
|
||||
|
||||
string auth_token;
|
||||
TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));
|
||||
TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));
|
||||
|
||||
std::unique_ptr<char[]> scratch(new char[kBufferSize]);
|
||||
StringPiece response_piece;
|
||||
@ -496,7 +481,7 @@ Status GcsFileSystem::RenameFile(const string& src, const string& target) {
|
||||
TF_RETURN_IF_ERROR(ParseGcsPath(target, &target_bucket, &target_object));
|
||||
|
||||
string auth_token;
|
||||
TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));
|
||||
TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_.get(), &auth_token));
|
||||
|
||||
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
|
||||
TF_RETURN_IF_ERROR(request->Init());
|
||||
|
@ -19,18 +19,12 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/cloud/auth_provider.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request.h"
|
||||
#include "tensorflow/core/platform/file_system.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/// Interface for a provider of HTTP auth bearer tokens.
|
||||
class AuthProvider {
|
||||
public:
|
||||
virtual ~AuthProvider() {}
|
||||
virtual Status GetToken(string* t) const = 0;
|
||||
};
|
||||
|
||||
/// Google Cloud Storage implementation of a file system.
|
||||
class GcsFileSystem : public FileSystem {
|
||||
public:
|
||||
|
@ -16,101 +16,15 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/cloud/gcs_file_system.h"
|
||||
#include <fstream>
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request_fake.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class FakeHttpRequest : public HttpRequest {
|
||||
public:
|
||||
FakeHttpRequest(const string& request, const string& response)
|
||||
: FakeHttpRequest(request, response, Status::OK()) {}
|
||||
|
||||
FakeHttpRequest(const string& request, const string& response,
|
||||
Status response_status)
|
||||
: expected_request_(request),
|
||||
response_(response),
|
||||
response_status_(response_status) {}
|
||||
|
||||
Status Init() override { return Status::OK(); }
|
||||
Status SetUri(const string& uri) override {
|
||||
actual_request_ += "Uri: " + uri + "\n";
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetRange(uint64 start, uint64 end) override {
|
||||
actual_request_ += strings::StrCat("Range: ", start, "-", end, "\n");
|
||||
return Status::OK();
|
||||
}
|
||||
Status AddAuthBearerHeader(const string& auth_token) override {
|
||||
actual_request_ += "Auth Token: " + auth_token + "\n";
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetDeleteRequest() override {
|
||||
actual_request_ += "Delete: yes\n";
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetPostRequest(const string& body_filepath) override {
|
||||
std::ifstream stream(body_filepath);
|
||||
string content((std::istreambuf_iterator<char>(stream)),
|
||||
std::istreambuf_iterator<char>());
|
||||
actual_request_ += "Post body: " + content + "\n";
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetPostRequest() override {
|
||||
actual_request_ += "Post: yes\n";
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetResultBuffer(char* scratch, size_t size,
|
||||
StringPiece* result) override {
|
||||
scratch_ = scratch;
|
||||
size_ = size;
|
||||
result_ = result;
|
||||
return Status::OK();
|
||||
}
|
||||
Status Send() override {
|
||||
EXPECT_EQ(expected_request_, actual_request_) << "Unexpected HTTP request.";
|
||||
if (scratch_ && result_) {
|
||||
auto actual_size = std::min(response_.size(), size_);
|
||||
memcpy(scratch_, response_.c_str(), actual_size);
|
||||
*result_ = StringPiece(scratch_, actual_size);
|
||||
}
|
||||
return response_status_;
|
||||
}
|
||||
|
||||
private:
|
||||
char* scratch_ = nullptr;
|
||||
size_t size_ = 0;
|
||||
StringPiece* result_ = nullptr;
|
||||
string expected_request_;
|
||||
string actual_request_;
|
||||
string response_;
|
||||
Status response_status_;
|
||||
};
|
||||
|
||||
class FakeHttpRequestFactory : public HttpRequest::Factory {
|
||||
public:
|
||||
FakeHttpRequestFactory(const std::vector<HttpRequest*>* requests)
|
||||
: requests_(requests) {}
|
||||
|
||||
~FakeHttpRequestFactory() {
|
||||
EXPECT_EQ(current_index_, requests_->size())
|
||||
<< "Not all expected requests were made.";
|
||||
}
|
||||
|
||||
HttpRequest* Create() override {
|
||||
EXPECT_LT(current_index_, requests_->size())
|
||||
<< "Too many calls of HttpRequest factory.";
|
||||
return (*requests_)[current_index_++];
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<HttpRequest*>* requests_;
|
||||
int current_index_ = 0;
|
||||
};
|
||||
|
||||
class FakeAuthProvider : public AuthProvider {
|
||||
public:
|
||||
Status GetToken(string* token) const override {
|
||||
Status GetToken(string* token) override {
|
||||
*token = "fake_token";
|
||||
return Status::OK();
|
||||
}
|
||||
|
190
tensorflow/core/platform/cloud/google_auth_provider.cc
Normal file
190
tensorflow/core/platform/cloud/google_auth_provider.cc
Normal file
@ -0,0 +1,190 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/cloud/google_auth_provider.h"
|
||||
#include <pwd.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <fstream>
|
||||
#include "include/json/json.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// The environment variable pointing to the file with local
|
||||
// Application Default Credentials.
|
||||
constexpr char kGoogleApplicationCredentials[] =
|
||||
"GOOGLE_APPLICATION_CREDENTIALS";
|
||||
|
||||
// The environment variable which can override '~/.config/gcloud' if set.
|
||||
constexpr char kCloudSdkConfig[] = "CLOUDSDK_CONFIG";
|
||||
|
||||
// The default path to the gcloud config folder, relative to the home folder.
|
||||
constexpr char kGCloudConfigFolder[] = ".config/gcloud/";
|
||||
|
||||
// The name of the well-known credentials JSON file in the gcloud config folder.
|
||||
constexpr char kWellKnownCredentialsFile[] =
|
||||
"application_default_credentials.json";
|
||||
|
||||
// The minimum time delta between now and the token expiration time
|
||||
// for the token to be re-used.
|
||||
constexpr int kExpirationTimeMarginSec = 10;
|
||||
|
||||
// The URL to retrieve the auth bearer token via OAuth with a refresh token.
|
||||
constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token";
|
||||
|
||||
// The URL to retrieve the auth bearer token via OAuth with a private key.
|
||||
constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token";
|
||||
|
||||
// The URL to retrieve the auth bearer token when running in Google Compute
|
||||
// Engine.
|
||||
constexpr char kGceTokenUrl[] =
|
||||
"http://metadata/computeMetadata/v1/instance/service-accounts/default/"
|
||||
"token";
|
||||
|
||||
// The authentication token scope to request.
|
||||
constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";
|
||||
|
||||
/// Returns whether the given path points to a readable file.
|
||||
bool IsFile(const string& filename) {
|
||||
std::ifstream fstream(filename.c_str());
|
||||
return fstream.good();
|
||||
}
|
||||
|
||||
/// Returns the credentials file name from the env variable.
|
||||
Status GetEnvironmentVariableFileName(string* filename) {
|
||||
if (!filename) {
|
||||
return errors::FailedPrecondition("'filename' cannot be nullptr.");
|
||||
}
|
||||
const char* result = std::getenv(kGoogleApplicationCredentials);
|
||||
if (!result || !IsFile(result)) {
|
||||
return errors::NotFound(strings::StrCat("$", kGoogleApplicationCredentials,
|
||||
" is not set or corrupt."));
|
||||
}
|
||||
*filename = result;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Returns the well known file produced by command 'gcloud auth login'.
|
||||
Status GetWellKnownFileName(string* filename) {
|
||||
if (!filename) {
|
||||
return errors::FailedPrecondition("'filename' cannot be nullptr.");
|
||||
}
|
||||
string config_dir;
|
||||
const char* config_dir_override = std::getenv(kCloudSdkConfig);
|
||||
if (config_dir_override) {
|
||||
config_dir = config_dir_override;
|
||||
} else {
|
||||
// Determine the home dir path.
|
||||
const char* home_dir = std::getenv("HOME");
|
||||
if (!home_dir) {
|
||||
return errors::FailedPrecondition("Could not read $HOME.");
|
||||
}
|
||||
config_dir = io::JoinPath(home_dir, kGCloudConfigFolder);
|
||||
}
|
||||
auto result = io::JoinPath(config_dir, kWellKnownCredentialsFile);
|
||||
if (!IsFile(result)) {
|
||||
return errors::NotFound(
|
||||
"Could not find the credentials file in the standard gcloud location.");
|
||||
}
|
||||
*filename = result;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GoogleAuthProvider::GoogleAuthProvider()
|
||||
: GoogleAuthProvider(
|
||||
std::unique_ptr<OAuthClient>(new OAuthClient()),
|
||||
std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory()),
|
||||
Env::Default()) {}
|
||||
|
||||
GoogleAuthProvider::GoogleAuthProvider(
|
||||
std::unique_ptr<OAuthClient> oauth_client,
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env)
|
||||
: oauth_client_(std::move(oauth_client)),
|
||||
http_request_factory_(std::move(http_request_factory)),
|
||||
env_(env) {}
|
||||
|
||||
Status GoogleAuthProvider::GetToken(string* t) {
|
||||
const uint64 now_sec = env_->NowSeconds();
|
||||
|
||||
if (!current_token_.empty() &&
|
||||
now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
|
||||
*t = current_token_;
|
||||
return Status::OK();
|
||||
}
|
||||
if (GetTokenFromFiles().ok() || GetTokenFromGce().ok()) {
|
||||
*t = current_token_;
|
||||
return Status::OK();
|
||||
}
|
||||
return errors::FailedPrecondition(
|
||||
"All attempts to get a Google authentication bearer token failed.");
|
||||
}
|
||||
|
||||
Status GoogleAuthProvider::GetTokenFromFiles() {
|
||||
string credentials_filename;
|
||||
if (!GetEnvironmentVariableFileName(&credentials_filename).ok() &&
|
||||
!GetWellKnownFileName(&credentials_filename).ok()) {
|
||||
return errors::NotFound("Could not locate the credentials file.");
|
||||
}
|
||||
|
||||
Json::Value json;
|
||||
Json::Reader reader;
|
||||
std::ifstream credentials_fstream(credentials_filename);
|
||||
if (!reader.parse(credentials_fstream, json)) {
|
||||
return errors::FailedPrecondition(
|
||||
"Couldn't parse the JSON credentials file.");
|
||||
}
|
||||
if (json.isMember("refresh_token")) {
|
||||
TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson(
|
||||
json, kOAuthV3Url, ¤t_token_, &expiration_timestamp_sec_));
|
||||
} else if (json.isMember("private_key")) {
|
||||
TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson(
|
||||
json, kOAuthV4Url, kOAuthScope, ¤t_token_,
|
||||
&expiration_timestamp_sec_));
|
||||
} else {
|
||||
return errors::FailedPrecondition(
|
||||
"Unexpected content of the JSON credentials file.");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GoogleAuthProvider::GetTokenFromGce() {
|
||||
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
|
||||
std::unique_ptr<char[]> response_buffer(
|
||||
new char[OAuthClient::kResponseBufferSize]);
|
||||
const uint64 request_timestamp_sec = env_->NowSeconds();
|
||||
StringPiece response;
|
||||
TF_RETURN_IF_ERROR(request->Init());
|
||||
TF_RETURN_IF_ERROR(request->SetUri(kGceTokenUrl));
|
||||
TF_RETURN_IF_ERROR(request->AddHeader("Metadata-Flavor", "Google"));
|
||||
TF_RETURN_IF_ERROR(request->SetResultBuffer(
|
||||
response_buffer.get(), OAuthClient::kResponseBufferSize, &response));
|
||||
TF_RETURN_IF_ERROR(request->Send());
|
||||
|
||||
TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse(
|
||||
response, request_timestamp_sec, ¤t_token_,
|
||||
&expiration_timestamp_sec_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
57
tensorflow/core/platform/cloud/google_auth_provider.h
Normal file
57
tensorflow/core/platform/cloud/google_auth_provider.h
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
|
||||
|
||||
#include <memory>
|
||||
#include "tensorflow/core/platform/cloud/auth_provider.h"
|
||||
#include "tensorflow/core/platform/cloud/oauth_client.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/// Implementation based on Google Application Default Credentials.
|
||||
class GoogleAuthProvider : public AuthProvider {
|
||||
public:
|
||||
GoogleAuthProvider();
|
||||
explicit GoogleAuthProvider(
|
||||
std::unique_ptr<OAuthClient> oauth_client,
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env);
|
||||
virtual ~GoogleAuthProvider() {}
|
||||
|
||||
/// Returns the short-term authentication bearer token.
|
||||
Status GetToken(string* token) override;
|
||||
|
||||
private:
|
||||
/// \brief Gets the bearer token from files.
|
||||
///
|
||||
/// Tries the file from $GOOGLE_APPLICATION_CREDENTIALS and the
|
||||
/// standard gcloud tool's location.
|
||||
Status GetTokenFromFiles();
|
||||
|
||||
/// Gets the bearer token from Google Compute Engine environment.
|
||||
Status GetTokenFromGce();
|
||||
|
||||
std::unique_ptr<OAuthClient> oauth_client_;
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
|
||||
Env* env_;
|
||||
string current_token_;
|
||||
uint64 expiration_timestamp_sec_ = 0;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GoogleAuthProvider);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
|
200
tensorflow/core/platform/cloud/google_auth_provider_test.cc
Normal file
200
tensorflow/core/platform/cloud/google_auth_provider_test.cc
Normal file
@ -0,0 +1,200 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/cloud/google_auth_provider.h"
|
||||
#include <stdlib.h>
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request_fake.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kTestData[] = "core/platform/cloud/testdata/";
|
||||
|
||||
class FakeEnv : public EnvWrapper {
|
||||
public:
|
||||
FakeEnv() : EnvWrapper(Env::Default()) {}
|
||||
|
||||
uint64 NowSeconds() override { return now; }
|
||||
uint64 now = 10000;
|
||||
};
|
||||
|
||||
class FakeOAuthClient : public OAuthClient {
|
||||
public:
|
||||
Status GetTokenFromServiceAccountJson(
|
||||
Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
|
||||
string* token, uint64* expiration_timestamp_sec) override {
|
||||
provided_credentials_json = json;
|
||||
*token = return_token;
|
||||
*expiration_timestamp_sec = return_expiration_timestamp;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Retrieves a bearer token using a refresh token.
|
||||
Status GetTokenFromRefreshTokenJson(
|
||||
Json::Value json, StringPiece oauth_server_uri, string* token,
|
||||
uint64* expiration_timestamp_sec) override {
|
||||
provided_credentials_json = json;
|
||||
*token = return_token;
|
||||
*expiration_timestamp_sec = return_expiration_timestamp;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
string return_token;
|
||||
uint64 return_expiration_timestamp;
|
||||
Json::Value provided_credentials_json;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(GoogleAuthProvider, EnvironmentVariable_Caching) {
|
||||
setenv("GOOGLE_APPLICATION_CREDENTIALS",
|
||||
io::JoinPath(
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(),
|
||||
"service_account_credentials.json")
|
||||
.c_str(),
|
||||
1);
|
||||
setenv("CLOUDSDK_CONFIG",
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(),
|
||||
1); // Will not be used.
|
||||
|
||||
auto oauth_client = new FakeOAuthClient;
|
||||
std::vector<HttpRequest*> requests;
|
||||
|
||||
FakeEnv env;
|
||||
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
|
||||
std::unique_ptr<HttpRequest::Factory>(
|
||||
new FakeHttpRequestFactory(&requests)),
|
||||
&env);
|
||||
oauth_client->return_token = "fake-token";
|
||||
oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;
|
||||
|
||||
string token;
|
||||
TF_EXPECT_OK(provider.GetToken(&token));
|
||||
EXPECT_EQ("fake-token", token);
|
||||
EXPECT_EQ("fake_key_id",
|
||||
oauth_client->provided_credentials_json.get("private_key_id", "")
|
||||
.asString());
|
||||
|
||||
// Check that the token is re-used if not expired.
|
||||
oauth_client->return_token = "new-fake-token";
|
||||
env.now += 3000;
|
||||
TF_EXPECT_OK(provider.GetToken(&token));
|
||||
EXPECT_EQ("fake-token", token);
|
||||
|
||||
// Check that the token is re-generated when almost expired.
|
||||
env.now += 598; // 2 seconds before expiration
|
||||
TF_EXPECT_OK(provider.GetToken(&token));
|
||||
EXPECT_EQ("new-fake-token", token);
|
||||
}
|
||||
|
||||
TEST(GoogleAuthProvider, GCloudRefreshToken) {
|
||||
setenv("GOOGLE_APPLICATION_CREDENTIALS", "", 1);
|
||||
setenv("CLOUDSDK_CONFIG",
|
||||
io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(), 1);
|
||||
|
||||
auto oauth_client = new FakeOAuthClient;
|
||||
std::vector<HttpRequest*> requests;
|
||||
|
||||
FakeEnv env;
|
||||
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
|
||||
std::unique_ptr<HttpRequest::Factory>(
|
||||
new FakeHttpRequestFactory(&requests)),
|
||||
&env);
|
||||
oauth_client->return_token = "fake-token";
|
||||
oauth_client->return_expiration_timestamp = env.NowSeconds() + 3600;
|
||||
|
||||
string token;
|
||||
TF_EXPECT_OK(provider.GetToken(&token));
|
||||
EXPECT_EQ("fake-token", token);
|
||||
EXPECT_EQ("fake-refresh-token",
|
||||
oauth_client->provided_credentials_json.get("refresh_token", "")
|
||||
.asString());
|
||||
}
|
||||
|
||||
TEST(GoogleAuthProvider, RunningOnGCE) {
|
||||
setenv("GOOGLE_APPLICATION_CREDENTIALS", "", 1);
|
||||
setenv("CLOUDSDK_CONFIG", "", 1);
|
||||
|
||||
auto oauth_client = new FakeOAuthClient;
|
||||
std::vector<HttpRequest*> requests(
|
||||
{new FakeHttpRequest(
|
||||
"Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
|
||||
"/default/token\n"
|
||||
"Header Metadata-Flavor: Google\n",
|
||||
R"(
|
||||
{
|
||||
"access_token":"fake-gce-token",
|
||||
"expires_in": 3920,
|
||||
"token_type":"Bearer"
|
||||
})"),
|
||||
new FakeHttpRequest(
|
||||
"Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
|
||||
"/default/token\n"
|
||||
"Header Metadata-Flavor: Google\n",
|
||||
R"(
|
||||
{
|
||||
"access_token":"new-fake-gce-token",
|
||||
"expires_in": 3920,
|
||||
"token_type":"Bearer"
|
||||
})")});
|
||||
|
||||
FakeEnv env;
|
||||
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
|
||||
std::unique_ptr<HttpRequest::Factory>(
|
||||
new FakeHttpRequestFactory(&requests)),
|
||||
&env);
|
||||
|
||||
string token;
|
||||
TF_EXPECT_OK(provider.GetToken(&token));
|
||||
EXPECT_EQ("fake-gce-token", token);
|
||||
|
||||
// Check that the token is re-used if not expired.
|
||||
env.now += 3700;
|
||||
TF_EXPECT_OK(provider.GetToken(&token));
|
||||
EXPECT_EQ("fake-gce-token", token);
|
||||
|
||||
// Check that the token is re-generated when almost expired.
|
||||
env.now += 598; // 2 seconds before expiration
|
||||
TF_EXPECT_OK(provider.GetToken(&token));
|
||||
EXPECT_EQ("new-fake-gce-token", token);
|
||||
}
|
||||
|
||||
TEST(GoogleAuthProvider, NothingAvailable) {
|
||||
setenv("GOOGLE_APPLICATION_CREDENTIALS", "", 1);
|
||||
setenv("CLOUDSDK_CONFIG", "", 1);
|
||||
|
||||
auto oauth_client = new FakeOAuthClient;
|
||||
|
||||
std::vector<HttpRequest*> requests({new FakeHttpRequest(
|
||||
"Uri: http://metadata/computeMetadata/v1/instance/service-accounts"
|
||||
"/default/token\n"
|
||||
"Header Metadata-Flavor: Google\n",
|
||||
"", errors::NotFound("404"))});
|
||||
|
||||
FakeEnv env;
|
||||
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
|
||||
std::unique_ptr<HttpRequest::Factory>(
|
||||
new FakeHttpRequestFactory(&requests)),
|
||||
&env);
|
||||
|
||||
string token;
|
||||
EXPECT_FALSE(provider.GetToken(&token).ok());
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -200,8 +201,11 @@ HttpRequest::~HttpRequest() {
|
||||
}
|
||||
|
||||
Status HttpRequest::Init() {
|
||||
if (is_initialized_) {
|
||||
return errors::FailedPrecondition("Already initialized.");
|
||||
}
|
||||
if (!libcurl_) {
|
||||
return errors::Internal("libcurl proxy cannot be nullptr.");
|
||||
return errors::FailedPrecondition("libcurl proxy cannot be nullptr.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(libcurl_->MaybeLoadDll());
|
||||
curl_ = libcurl_->curl_easy_init();
|
||||
@ -211,6 +215,9 @@ Status HttpRequest::Init() {
|
||||
|
||||
libcurl_->curl_easy_setopt(curl_, CURLOPT_VERBOSE, kVerboseOutput);
|
||||
libcurl_->curl_easy_setopt(curl_, CURLOPT_CAPATH, kCertsPath);
|
||||
libcurl_->curl_easy_setopt(
|
||||
curl_, CURLOPT_USERAGENT,
|
||||
strings::StrCat("TensorFlow/", TF_VERSION_STRING).c_str());
|
||||
|
||||
// If response buffer is not set, libcurl will print results to stdout,
|
||||
// so we always set it.
|
||||
@ -240,13 +247,19 @@ Status HttpRequest::SetRange(uint64 start, uint64 end) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HttpRequest::AddHeader(const string& name, const string& value) {
|
||||
TF_RETURN_IF_ERROR(CheckInitialized());
|
||||
TF_RETURN_IF_ERROR(CheckNotSent());
|
||||
curl_headers_ = libcurl_->curl_slist_append(
|
||||
curl_headers_, strings::StrCat(name, ": ", value).c_str());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HttpRequest::AddAuthBearerHeader(const string& auth_token) {
|
||||
TF_RETURN_IF_ERROR(CheckInitialized());
|
||||
TF_RETURN_IF_ERROR(CheckNotSent());
|
||||
if (!auth_token.empty()) {
|
||||
curl_headers_ = libcurl_->curl_slist_append(
|
||||
curl_headers_,
|
||||
strings::StrCat("Authorization: Bearer ", auth_token).c_str());
|
||||
return AddHeader("Authorization", strings::StrCat("Bearer ", auth_token));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -285,6 +298,22 @@ Status HttpRequest::SetPostRequest(const string& body_filepath) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HttpRequest::SetPostRequest(const char* buffer, size_t size) {
|
||||
TF_RETURN_IF_ERROR(CheckInitialized());
|
||||
TF_RETURN_IF_ERROR(CheckNotSent());
|
||||
TF_RETURN_IF_ERROR(CheckMethodNotSet());
|
||||
is_method_set_ = true;
|
||||
curl_headers_ = libcurl_->curl_slist_append(
|
||||
curl_headers_, strings::StrCat("Content-Length: ", size).c_str());
|
||||
libcurl_->curl_easy_setopt(curl_, CURLOPT_POST, 1);
|
||||
libcurl_->curl_easy_setopt(curl_, CURLOPT_READDATA,
|
||||
reinterpret_cast<void*>(this));
|
||||
libcurl_->curl_easy_setopt(curl_, CURLOPT_READFUNCTION,
|
||||
&HttpRequest::ReadCallback);
|
||||
post_body_buffer_ = StringPiece(buffer, size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HttpRequest::SetPostRequest() {
|
||||
TF_RETURN_IF_ERROR(CheckInitialized());
|
||||
TF_RETURN_IF_ERROR(CheckNotSent());
|
||||
@ -337,6 +366,19 @@ size_t HttpRequest::WriteCallback(const void* ptr, size_t size, size_t nmemb,
|
||||
return bytes_to_copy;
|
||||
}
|
||||
|
||||
size_t HttpRequest::ReadCallback(void* ptr, size_t size, size_t nmemb,
|
||||
FILE* this_object) {
|
||||
CHECK(ptr);
|
||||
auto that = reinterpret_cast<HttpRequest*>(this_object);
|
||||
CHECK(that->post_body_read_ <= that->post_body_buffer_.size());
|
||||
const size_t bytes_to_copy = std::min(
|
||||
size * nmemb, that->post_body_buffer_.size() - that->post_body_read_);
|
||||
memcpy(ptr, that->post_body_buffer_.data() + that->post_body_read_,
|
||||
bytes_to_copy);
|
||||
that->post_body_read_ += bytes_to_copy;
|
||||
return bytes_to_copy;
|
||||
}
|
||||
|
||||
Status HttpRequest::Send() {
|
||||
TF_RETURN_IF_ERROR(CheckInitialized());
|
||||
TF_RETURN_IF_ERROR(CheckNotSent());
|
||||
|
@ -16,7 +16,6 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
|
||||
|
||||
#include <functional>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <curl/curl.h>
|
||||
@ -64,6 +63,9 @@ class HttpRequest {
|
||||
/// (note that the right border is included).
|
||||
virtual Status SetRange(uint64 start, uint64 end);
|
||||
|
||||
/// Sets a request header.
|
||||
virtual Status AddHeader(const string& name, const string& value);
|
||||
|
||||
/// Sets the 'Authorization' header to the value of 'Bearer ' + auth_token.
|
||||
virtual Status AddAuthBearerHeader(const string& auth_token);
|
||||
|
||||
@ -75,6 +77,11 @@ class HttpRequest {
|
||||
/// The request body will be taken from the specified file.
|
||||
virtual Status SetPostRequest(const string& body_filepath);
|
||||
|
||||
/// \brief Makes the request a POST request.
|
||||
///
|
||||
/// The request body will be taken from the specified buffer.
|
||||
virtual Status SetPostRequest(const char* buffer, size_t size);
|
||||
|
||||
/// Makes the request a POST request.
|
||||
virtual Status SetPostRequest();
|
||||
|
||||
@ -91,15 +98,23 @@ class HttpRequest {
|
||||
virtual Status Send();
|
||||
|
||||
private:
|
||||
/// A callback in the form which can be accepted by libcurl.
|
||||
/// A write callback in the form which can be accepted by libcurl.
|
||||
static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb,
|
||||
void* userdata);
|
||||
/// A read callback in the form which can be accepted by libcurl.
|
||||
static size_t ReadCallback(void* ptr, size_t size, size_t nmemb,
|
||||
FILE* userdata);
|
||||
Status CheckInitialized() const;
|
||||
Status CheckMethodNotSet() const;
|
||||
Status CheckNotSent() const;
|
||||
|
||||
std::unique_ptr<LibCurl> libcurl_;
|
||||
|
||||
FILE* post_body_ = nullptr;
|
||||
|
||||
StringPiece post_body_buffer_;
|
||||
size_t post_body_read_ = 0;
|
||||
|
||||
char* response_buffer_ = nullptr;
|
||||
size_t response_buffer_size_ = 0;
|
||||
size_t response_buffer_written_ = 0;
|
||||
|
166
tensorflow/core/platform/cloud/http_request_fake.h
Normal file
166
tensorflow/core/platform/cloud/http_request_fake.h
Normal file
@ -0,0 +1,166 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
|
||||
#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
|
||||
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <curl/curl.h>
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/// Fake HttpRequest for testing.
|
||||
class FakeHttpRequest : public HttpRequest {
|
||||
public:
|
||||
/// Return the response for the given request.
|
||||
FakeHttpRequest(const string& request, const string& response)
|
||||
: FakeHttpRequest(request, response, Status::OK(), nullptr) {}
|
||||
|
||||
/// \brief Return the response for the request and capture the POST body.
|
||||
///
|
||||
/// Post body is not expected to be a part of the 'request' parameter.
|
||||
FakeHttpRequest(const string& request, const string& response,
|
||||
string* captured_post_body)
|
||||
: FakeHttpRequest(request, response, Status::OK(), captured_post_body) {}
|
||||
|
||||
/// \brief Return the response and the status for the given request.
|
||||
FakeHttpRequest(const string& request, const string& response,
|
||||
Status response_status)
|
||||
: FakeHttpRequest(request, response, response_status, nullptr) {}
|
||||
|
||||
/// \brief Return the response and the status for the given request
|
||||
/// and capture the POST body.
|
||||
///
|
||||
/// Post body is not expected to be a part of the 'request' parameter.
|
||||
FakeHttpRequest(const string& request, const string& response,
|
||||
Status response_status, string* captured_post_body)
|
||||
: expected_request_(request),
|
||||
response_(response),
|
||||
response_status_(response_status),
|
||||
captured_post_body_(captured_post_body) {}
|
||||
|
||||
Status Init() override { return Status::OK(); }
|
||||
Status SetUri(const string& uri) override {
|
||||
actual_request_ += "Uri: " + uri + "\n";
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetRange(uint64 start, uint64 end) override {
|
||||
actual_request_ += strings::StrCat("Range: ", start, "-", end, "\n");
|
||||
return Status::OK();
|
||||
}
|
||||
Status AddHeader(const string& name, const string& value) override {
|
||||
actual_request_ += "Header " + name + ": " + value + "\n";
|
||||
return Status::OK();
|
||||
}
|
||||
Status AddAuthBearerHeader(const string& auth_token) override {
|
||||
actual_request_ += "Auth Token: " + auth_token + "\n";
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetDeleteRequest() override {
|
||||
actual_request_ += "Delete: yes\n";
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetPostRequest(const string& body_filepath) override {
|
||||
std::ifstream stream(body_filepath);
|
||||
string content((std::istreambuf_iterator<char>(stream)),
|
||||
std::istreambuf_iterator<char>());
|
||||
if (captured_post_body_) {
|
||||
*captured_post_body_ = content;
|
||||
} else {
|
||||
actual_request_ += "Post body: " + content + "\n";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetPostRequest(const char* buffer, size_t size) override {
|
||||
if (captured_post_body_) {
|
||||
*captured_post_body_ = string(buffer, size);
|
||||
} else {
|
||||
actual_request_ +=
|
||||
strings::StrCat("Post body: ", StringPiece(buffer, size), "\n");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetPostRequest() override {
|
||||
if (captured_post_body_) {
|
||||
*captured_post_body_ = "<empty>";
|
||||
} else {
|
||||
actual_request_ += "Post: yes\n";
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
Status SetResultBuffer(char* scratch, size_t size,
|
||||
StringPiece* result) override {
|
||||
scratch_ = scratch;
|
||||
size_ = size;
|
||||
result_ = result;
|
||||
return Status::OK();
|
||||
}
|
||||
Status Send() override {
|
||||
EXPECT_EQ(expected_request_, actual_request_) << "Unexpected HTTP request.";
|
||||
if (scratch_ && result_) {
|
||||
auto actual_size = std::min(response_.size(), size_);
|
||||
memcpy(scratch_, response_.c_str(), actual_size);
|
||||
*result_ = StringPiece(scratch_, actual_size);
|
||||
}
|
||||
return response_status_;
|
||||
}
|
||||
|
||||
private:
|
||||
char* scratch_ = nullptr;
|
||||
size_t size_ = 0;
|
||||
StringPiece* result_ = nullptr;
|
||||
string expected_request_;
|
||||
string actual_request_;
|
||||
string response_;
|
||||
Status response_status_;
|
||||
string* captured_post_body_ = nullptr;
|
||||
};
|
||||
|
||||
/// Fake HttpRequest factory for testing.
|
||||
class FakeHttpRequestFactory : public HttpRequest::Factory {
|
||||
public:
|
||||
FakeHttpRequestFactory(const std::vector<HttpRequest*>* requests)
|
||||
: requests_(requests) {}
|
||||
|
||||
~FakeHttpRequestFactory() {
|
||||
EXPECT_EQ(current_index_, requests_->size())
|
||||
<< "Not all expected requests were made.";
|
||||
}
|
||||
|
||||
HttpRequest* Create() override {
|
||||
EXPECT_LT(current_index_, requests_->size())
|
||||
<< "Too many calls of HttpRequest factory.";
|
||||
return (*requests_)[current_index_++];
|
||||
}
|
||||
|
||||
private:
|
||||
const std::vector<HttpRequest*>* requests_;
|
||||
int current_index_ = 0;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
|
@ -81,7 +81,7 @@ class FakeLibCurl : public LibCurl {
|
||||
CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
|
||||
size_t (*param)(void*, size_t, size_t,
|
||||
FILE*)) override {
|
||||
EXPECT_EQ(param, &fread) << "Expected the standard fread() function.";
|
||||
read_callback = param;
|
||||
return CURLE_OK;
|
||||
}
|
||||
CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
|
||||
@ -98,11 +98,11 @@ class FakeLibCurl : public LibCurl {
|
||||
}
|
||||
CURLcode curl_easy_perform(CURL* curl) override {
|
||||
if (read_data) {
|
||||
char buffer[100];
|
||||
char buffer[3];
|
||||
int bytes_read;
|
||||
posted_content = "";
|
||||
do {
|
||||
bytes_read = fread(buffer, 1, 100, read_data);
|
||||
bytes_read = read_callback(buffer, 1, sizeof(buffer), read_data);
|
||||
posted_content =
|
||||
strings::StrCat(posted_content, StringPiece(buffer, bytes_read));
|
||||
} while (bytes_read > 0);
|
||||
@ -158,11 +158,13 @@ class FakeLibCurl : public LibCurl {
|
||||
bool is_initialized = false;
|
||||
bool is_cleaned_up = false;
|
||||
std::vector<string>* headers = nullptr;
|
||||
FILE* read_data = nullptr;
|
||||
bool is_post = false;
|
||||
void* write_data = nullptr;
|
||||
size_t (*write_callback)(const void* ptr, size_t size, size_t nmemb,
|
||||
void* userdata) = nullptr;
|
||||
FILE* read_data = nullptr;
|
||||
size_t (*read_callback)(void* ptr, size_t size, size_t nmemb,
|
||||
FILE* userdata) = &fread;
|
||||
// Outcome of performing the request.
|
||||
string posted_content;
|
||||
};
|
||||
@ -193,7 +195,7 @@ TEST(HttpRequestTest, GetRequest) {
|
||||
EXPECT_FALSE(libcurl->is_post);
|
||||
}
|
||||
|
||||
TEST(HttpRequestTest, PostRequest_WithBody) {
|
||||
TEST(HttpRequestTest, PostRequest_WithBody_FromFile) {
|
||||
FakeLibCurl* libcurl = new FakeLibCurl("", 200);
|
||||
HttpRequest http_request((std::unique_ptr<LibCurl>(libcurl)));
|
||||
TF_EXPECT_OK(http_request.Init());
|
||||
@ -221,6 +223,29 @@ TEST(HttpRequestTest, PostRequest_WithBody) {
|
||||
std::remove(content_filename.c_str());
|
||||
}
|
||||
|
||||
TEST(HttpRequestTest, PostRequest_WithBody_FromMemory) {
|
||||
FakeLibCurl* libcurl = new FakeLibCurl("", 200);
|
||||
HttpRequest http_request((std::unique_ptr<LibCurl>(libcurl)));
|
||||
TF_EXPECT_OK(http_request.Init());
|
||||
|
||||
string content = "post body content";
|
||||
|
||||
TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
|
||||
TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer"));
|
||||
TF_EXPECT_OK(http_request.SetPostRequest(content.c_str(), content.size()));
|
||||
TF_EXPECT_OK(http_request.Send());
|
||||
|
||||
// Check interactions with libcurl.
|
||||
EXPECT_TRUE(libcurl->is_initialized);
|
||||
EXPECT_EQ("http://www.testuri.com", libcurl->url);
|
||||
EXPECT_EQ("", libcurl->custom_request);
|
||||
EXPECT_EQ(2, libcurl->headers->size());
|
||||
EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl->headers)[0]);
|
||||
EXPECT_EQ("Content-Length: 17", (*libcurl->headers)[1]);
|
||||
EXPECT_TRUE(libcurl->is_post);
|
||||
EXPECT_EQ("post body content", libcurl->posted_content);
|
||||
}
|
||||
|
||||
TEST(HttpRequestTest, PostRequest_WithoutBody) {
|
||||
FakeLibCurl* libcurl = new FakeLibCurl("", 200);
|
||||
HttpRequest http_request((std::unique_ptr<LibCurl>(libcurl)));
|
||||
|
288
tensorflow/core/platform/cloud/oauth_client.cc
Normal file
288
tensorflow/core/platform/cloud/oauth_client.cc
Normal file
@ -0,0 +1,288 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/cloud/oauth_client.h"
|
||||
#include <pwd.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
#include <fstream>
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/pem.h>
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// The requested lifetime of a auth bearer token.
|
||||
constexpr int kRequestedTokenLifetimeSec = 3600;
|
||||
|
||||
// The crypto algorithm to be used with OAuth.
|
||||
constexpr char kCryptoAlgorithm[] = "RS256";
|
||||
|
||||
// The token type for the OAuth request.
|
||||
constexpr char kJwtType[] = "JWT";
|
||||
|
||||
// The grant type for the OAuth request. Already URL-encoded for convenience.
|
||||
constexpr char kGrantType[] =
|
||||
"urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer";
|
||||
|
||||
Status ReadJsonValue(Json::Value json, const string& name, Json::Value* value) {
|
||||
if (!value) {
|
||||
return errors::FailedPrecondition("'value' cannot be nullptr.");
|
||||
}
|
||||
*value = json.get(name, Json::Value::null);
|
||||
if (*value == Json::Value::null) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("Couldn't read a JSON value '", name, "'."));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadJsonString(Json::Value json, const string& name, string* value) {
|
||||
Json::Value json_value;
|
||||
TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value));
|
||||
if (!json_value.isString()) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("JSON value '", name, "' is not string."));
|
||||
}
|
||||
*value = json_value.asString();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ReadJsonInt(Json::Value json, const string& name, int64* value) {
|
||||
Json::Value json_value;
|
||||
TF_RETURN_IF_ERROR(ReadJsonValue(json, name, &json_value));
|
||||
if (!json_value.isIntegral()) {
|
||||
return errors::FailedPrecondition(
|
||||
strings::StrCat("JSON value '", name, "' is not integer."));
|
||||
}
|
||||
*value = json_value.asInt64();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CreateSignature(RSA* private_key, StringPiece to_sign,
|
||||
string* signature) {
|
||||
if (!private_key || !signature) {
|
||||
return errors::FailedPrecondition(
|
||||
"'private_key' and 'signature' cannot be nullptr.");
|
||||
}
|
||||
|
||||
const auto md = EVP_sha256();
|
||||
if (!md) {
|
||||
return errors::Internal("Could not get a sha256 encryptor.");
|
||||
}
|
||||
std::unique_ptr<EVP_MD_CTX, std::function<void(EVP_MD_CTX*)>> md_ctx(
|
||||
EVP_MD_CTX_create(), [](EVP_MD_CTX* ptr) { EVP_MD_CTX_destroy(ptr); });
|
||||
if (!md_ctx.get()) {
|
||||
return errors::Internal("Could not create MD_CTX.");
|
||||
}
|
||||
|
||||
std::unique_ptr<EVP_PKEY, std::function<void(EVP_PKEY*)>> key(
|
||||
EVP_PKEY_new(), [](EVP_PKEY* ptr) { EVP_PKEY_free(ptr); });
|
||||
EVP_PKEY_set1_RSA(key.get(), private_key);
|
||||
|
||||
if (EVP_DigestSignInit(md_ctx.get(), NULL, md, NULL, key.get()) != 1) {
|
||||
return errors::Internal("DigestInit failed.");
|
||||
}
|
||||
if (EVP_DigestSignUpdate(md_ctx.get(), to_sign.data(), to_sign.size()) != 1) {
|
||||
return errors::Internal("DigestUpdate failed.");
|
||||
}
|
||||
size_t sig_len = 0;
|
||||
if (EVP_DigestSignFinal(md_ctx.get(), NULL, &sig_len) != 1) {
|
||||
return errors::Internal("DigestFinal (get signature length) failed.");
|
||||
}
|
||||
std::unique_ptr<unsigned char[]> sig(new unsigned char[sig_len]);
|
||||
if (EVP_DigestSignFinal(md_ctx.get(), sig.get(), &sig_len) != 1) {
|
||||
return errors::Internal("DigestFinal (signature compute) failed.");
|
||||
}
|
||||
EVP_MD_CTX_cleanup(md_ctx.get());
|
||||
return Base64Encode(StringPiece(reinterpret_cast<char*>(sig.get()), sig_len),
|
||||
signature);
|
||||
}
|
||||
|
||||
/// Encodes a claim for a JSON web token (JWT) to make an OAuth request.
|
||||
Status EncodeJwtClaim(StringPiece client_email, StringPiece scope,
|
||||
StringPiece audience, uint64 request_timestamp_sec,
|
||||
string* encoded) {
|
||||
// Step 1: create the JSON with the claim.
|
||||
Json::Value root;
|
||||
root["iss"] = Json::Value(client_email.begin(), client_email.end());
|
||||
root["scope"] = Json::Value(scope.begin(), scope.end());
|
||||
root["aud"] = Json::Value(audience.begin(), audience.end());
|
||||
|
||||
const auto expiration_timestamp_sec =
|
||||
request_timestamp_sec + kRequestedTokenLifetimeSec;
|
||||
|
||||
root["iat"] = request_timestamp_sec;
|
||||
root["exp"] = expiration_timestamp_sec;
|
||||
|
||||
// Step 2: represent the JSON as a string.
|
||||
string claim = root.toStyledString();
|
||||
|
||||
// Step 3: encode the string as base64.
|
||||
return Base64Encode(claim, encoded);
|
||||
}
|
||||
|
||||
/// Encodes a header for a JSON web token (JWT) to make an OAuth request.
|
||||
Status EncodeJwtHeader(StringPiece key_id, string* encoded) {
|
||||
// Step 1: create the JSON with the header.
|
||||
Json::Value root;
|
||||
root["alg"] = kCryptoAlgorithm;
|
||||
root["typ"] = kJwtType;
|
||||
root["kid"] = Json::Value(key_id.begin(), key_id.end());
|
||||
|
||||
// Step 2: represent the JSON as a string.
|
||||
const string header = root.toStyledString();
|
||||
|
||||
// Step 3: encode the string as base64.
|
||||
return Base64Encode(header, encoded);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
OAuthClient::OAuthClient()
|
||||
: OAuthClient(
|
||||
std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory()),
|
||||
Env::Default()) {}
|
||||
|
||||
OAuthClient::OAuthClient(
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env)
|
||||
: http_request_factory_(std::move(http_request_factory)), env_(env) {}
|
||||
|
||||
Status OAuthClient::GetTokenFromServiceAccountJson(
|
||||
Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
|
||||
string* token, uint64* expiration_timestamp_sec) {
|
||||
if (!token || !expiration_timestamp_sec) {
|
||||
return errors::FailedPrecondition(
|
||||
"'token' and 'expiration_timestamp_sec' cannot be nullptr.");
|
||||
}
|
||||
string private_key_serialized, private_key_id, client_id, client_email;
|
||||
TF_RETURN_IF_ERROR(
|
||||
ReadJsonString(json, "private_key", &private_key_serialized));
|
||||
TF_RETURN_IF_ERROR(ReadJsonString(json, "private_key_id", &private_key_id));
|
||||
TF_RETURN_IF_ERROR(ReadJsonString(json, "client_id", &client_id));
|
||||
TF_RETURN_IF_ERROR(ReadJsonString(json, "client_email", &client_email));
|
||||
|
||||
std::unique_ptr<BIO, std::function<void(BIO*)>> bio(
|
||||
BIO_new(BIO_s_mem()), [](BIO* ptr) { BIO_free_all(ptr); });
|
||||
if (BIO_puts(bio.get(), private_key_serialized.c_str()) !=
|
||||
static_cast<int>(private_key_serialized.size())) {
|
||||
return errors::Internal("Could not load the private key.");
|
||||
}
|
||||
std::unique_ptr<RSA, std::function<void(RSA*)>> private_key(
|
||||
PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, nullptr, nullptr),
|
||||
[](RSA* ptr) { RSA_free(ptr); });
|
||||
if (!private_key.get()) {
|
||||
return errors::Internal("Could not deserialize the private key.");
|
||||
}
|
||||
|
||||
const uint64 request_timestamp_sec = env_->NowSeconds();
|
||||
|
||||
string encoded_claim, encoded_header;
|
||||
TF_RETURN_IF_ERROR(EncodeJwtHeader(private_key_id, &encoded_header));
|
||||
TF_RETURN_IF_ERROR(EncodeJwtClaim(client_email, scope, oauth_server_uri,
|
||||
request_timestamp_sec, &encoded_claim));
|
||||
const string to_sign = encoded_header + "." + encoded_claim;
|
||||
string signature;
|
||||
TF_RETURN_IF_ERROR(CreateSignature(private_key.get(), to_sign, &signature));
|
||||
const string jwt = to_sign + "." + signature;
|
||||
const string request_body =
|
||||
strings::StrCat("grant_type=", kGrantType, "&assertion=", jwt);
|
||||
|
||||
// Send the request to the Google OAuth 2.0 server to get the token.
|
||||
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
|
||||
std::unique_ptr<char[]> response_buffer(new char[kResponseBufferSize]);
|
||||
StringPiece response;
|
||||
TF_RETURN_IF_ERROR(request->Init());
|
||||
TF_RETURN_IF_ERROR(request->SetUri(oauth_server_uri.ToString()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
request->SetPostRequest(request_body.c_str(), request_body.size()));
|
||||
TF_RETURN_IF_ERROR(request->SetResultBuffer(response_buffer.get(),
|
||||
kResponseBufferSize, &response));
|
||||
TF_RETURN_IF_ERROR(request->Send());
|
||||
|
||||
TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token,
|
||||
expiration_timestamp_sec));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OAuthClient::GetTokenFromRefreshTokenJson(
|
||||
Json::Value json, StringPiece oauth_server_uri, string* token,
|
||||
uint64* expiration_timestamp_sec) {
|
||||
if (!token || !expiration_timestamp_sec) {
|
||||
return errors::FailedPrecondition(
|
||||
"'token' and 'expiration_timestamp_sec' cannot be nullptr.");
|
||||
}
|
||||
string client_id, client_secret, refresh_token;
|
||||
TF_RETURN_IF_ERROR(ReadJsonString(json, "client_id", &client_id));
|
||||
TF_RETURN_IF_ERROR(ReadJsonString(json, "client_secret", &client_secret));
|
||||
TF_RETURN_IF_ERROR(ReadJsonString(json, "refresh_token", &refresh_token));
|
||||
|
||||
const auto request_body = strings::StrCat(
|
||||
"client_id=", client_id, "&client_secret=", client_secret,
|
||||
"&refresh_token=", refresh_token, "&grant_type=refresh_token");
|
||||
|
||||
const uint64 request_timestamp_sec = env_->NowSeconds();
|
||||
|
||||
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
|
||||
std::unique_ptr<char[]> response_buffer(new char[kResponseBufferSize]);
|
||||
StringPiece response;
|
||||
TF_RETURN_IF_ERROR(request->Init());
|
||||
TF_RETURN_IF_ERROR(request->SetUri(oauth_server_uri.ToString()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
request->SetPostRequest(request_body.c_str(), request_body.size()));
|
||||
TF_RETURN_IF_ERROR(request->SetResultBuffer(response_buffer.get(),
|
||||
kResponseBufferSize, &response));
|
||||
TF_RETURN_IF_ERROR(request->Send());
|
||||
|
||||
TF_RETURN_IF_ERROR(ParseOAuthResponse(response, request_timestamp_sec, token,
|
||||
expiration_timestamp_sec));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OAuthClient::ParseOAuthResponse(StringPiece response,
|
||||
uint64 request_timestamp_sec,
|
||||
string* token,
|
||||
uint64* expiration_timestamp_sec) {
|
||||
if (!token || !expiration_timestamp_sec) {
|
||||
return errors::FailedPrecondition(
|
||||
"'token' and 'expiration_timestamp_sec' cannot be nullptr.");
|
||||
}
|
||||
Json::Value root;
|
||||
Json::Reader reader;
|
||||
if (!reader.parse(response.begin(), response.end(), root)) {
|
||||
return errors::Internal("Couldn't parse JSON response from OAuth server.");
|
||||
}
|
||||
|
||||
string token_type;
|
||||
TF_RETURN_IF_ERROR(ReadJsonString(root, "token_type", &token_type));
|
||||
if (token_type != "Bearer") {
|
||||
return errors::FailedPrecondition("Unexpected Oauth token type: " +
|
||||
token_type);
|
||||
}
|
||||
int64 expires_in;
|
||||
TF_RETURN_IF_ERROR(ReadJsonInt(root, "expires_in", &expires_in));
|
||||
*expiration_timestamp_sec = request_timestamp_sec + expires_in;
|
||||
TF_RETURN_IF_ERROR(ReadJsonString(root, "access_token", token));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
65
tensorflow/core/platform/cloud/oauth_client.h
Normal file
65
tensorflow/core/platform/cloud/oauth_client.h
Normal file
@ -0,0 +1,65 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_
|
||||
|
||||
#include <memory>
|
||||
#include "include/json/json.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
/// OAuth 2.0 client.
|
||||
class OAuthClient {
|
||||
public:
|
||||
OAuthClient();
|
||||
explicit OAuthClient(
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory, Env* env);
|
||||
virtual ~OAuthClient() {}
|
||||
|
||||
/// \brief Retrieves a bearer token using a private key.
|
||||
///
|
||||
/// Retrieves the authentication bearer token using a JSON file
|
||||
/// with the client's private key.
|
||||
virtual Status GetTokenFromServiceAccountJson(
|
||||
Json::Value json, StringPiece oauth_server_uri, StringPiece scope,
|
||||
string* token, uint64* expiration_timestamp_sec);
|
||||
|
||||
/// Retrieves a bearer token using a refresh token.
|
||||
virtual Status GetTokenFromRefreshTokenJson(Json::Value json,
|
||||
StringPiece oauth_server_uri,
|
||||
string* token,
|
||||
uint64* expiration_timestamp_sec);
|
||||
|
||||
/// Parses the JSON response with the token from an OAuth 2.0 server.
|
||||
virtual Status ParseOAuthResponse(StringPiece response,
|
||||
uint64 request_timestamp_sec, string* token,
|
||||
uint64* expiration_timestamp_sec);
|
||||
|
||||
/// The max size of the JSON response from an OAuth 2.0 server, in bytes.
|
||||
static constexpr size_t kResponseBufferSize = 1000;
|
||||
|
||||
private:
|
||||
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
|
||||
Env* env_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(OAuthClient);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CORE_PLATFORM_CLOUD_OAUTH_CLIENT_H_
|
203
tensorflow/core/platform/cloud/oauth_client_test.cc
Normal file
203
tensorflow/core/platform/cloud/oauth_client_test.cc
Normal file
@ -0,0 +1,203 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/platform/cloud/oauth_client.h"
|
||||
#include <fstream>
|
||||
#include <openssl/bio.h>
|
||||
#include <openssl/evp.h>
|
||||
#include <openssl/pem.h>
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/scanner.h"
|
||||
#include "tensorflow/core/platform/cloud/base64.h"
|
||||
#include "tensorflow/core/platform/cloud/http_request_fake.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
constexpr char kTestData[] = "core/platform/cloud/testdata/";
|
||||
|
||||
constexpr char kTokenJson[] = R"(
|
||||
{
|
||||
"access_token":"1/fFAGRNJru1FTz70BzhT3Zg",
|
||||
"expires_in":3920,
|
||||
"token_type":"Bearer"
|
||||
})";
|
||||
|
||||
class FakeEnv : public EnvWrapper {
|
||||
public:
|
||||
FakeEnv() : EnvWrapper(Env::Default()) {}
|
||||
|
||||
uint64 NowSeconds() override { return now; }
|
||||
uint64 now = 10000;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST(OAuthClientTest, ParseOAuthResponse) {
|
||||
const uint64 request_timestamp = 100;
|
||||
string token;
|
||||
uint64 expiration_timestamp;
|
||||
TF_EXPECT_OK(OAuthClient().ParseOAuthResponse(kTokenJson, request_timestamp,
|
||||
&token, &expiration_timestamp));
|
||||
EXPECT_EQ("1/fFAGRNJru1FTz70BzhT3Zg", token);
|
||||
EXPECT_EQ(4020, expiration_timestamp);
|
||||
}
|
||||
|
||||
TEST(OAuthClientTest, GetTokenFromRefreshTokenJson) {
|
||||
const string credentials_json = R"(
|
||||
{
|
||||
"client_id": "test_client_id",
|
||||
"client_secret": "test_client_secret",
|
||||
"refresh_token": "test_refresh_token",
|
||||
"type": "authorized_user"
|
||||
})";
|
||||
Json::Value json;
|
||||
Json::Reader reader;
|
||||
ASSERT_TRUE(reader.parse(credentials_json, json));
|
||||
|
||||
std::vector<HttpRequest*> requests({new FakeHttpRequest(
|
||||
"Uri: https://www.googleapis.com/oauth2/v3/token\n"
|
||||
"Post body: client_id=test_client_id&"
|
||||
"client_secret=test_client_secret&"
|
||||
"refresh_token=test_refresh_token&grant_type=refresh_token\n",
|
||||
kTokenJson)});
|
||||
FakeEnv env;
|
||||
OAuthClient client(std::unique_ptr<HttpRequest::Factory>(
|
||||
new FakeHttpRequestFactory(&requests)),
|
||||
&env);
|
||||
string token;
|
||||
uint64 expiration_timestamp;
|
||||
TF_EXPECT_OK(client.GetTokenFromRefreshTokenJson(
|
||||
json, "https://www.googleapis.com/oauth2/v3/token", &token,
|
||||
&expiration_timestamp));
|
||||
EXPECT_EQ("1/fFAGRNJru1FTz70BzhT3Zg", token);
|
||||
EXPECT_EQ(13920, expiration_timestamp);
|
||||
}
|
||||
|
||||
TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
|
||||
std::ifstream credentials(
|
||||
io::JoinPath(io::JoinPath(testing::TensorFlowSrcRoot(), kTestData),
|
||||
"service_account_credentials.json"));
|
||||
ASSERT_TRUE(credentials.is_open());
|
||||
Json::Value json;
|
||||
Json::Reader reader;
|
||||
ASSERT_TRUE(reader.parse(credentials, json));
|
||||
|
||||
string post_body;
|
||||
std::vector<HttpRequest*> requests(
|
||||
{new FakeHttpRequest("Uri: https://www.googleapis.com/oauth2/v3/token\n",
|
||||
kTokenJson, &post_body)});
|
||||
FakeEnv env;
|
||||
OAuthClient client(std::unique_ptr<HttpRequest::Factory>(
|
||||
new FakeHttpRequestFactory(&requests)),
|
||||
&env);
|
||||
string token;
|
||||
uint64 expiration_timestamp;
|
||||
TF_EXPECT_OK(client.GetTokenFromServiceAccountJson(
|
||||
json, "https://www.googleapis.com/oauth2/v3/token",
|
||||
"https://test-token-scope.com", &token, &expiration_timestamp));
|
||||
EXPECT_EQ("1/fFAGRNJru1FTz70BzhT3Zg", token);
|
||||
EXPECT_EQ(13920, expiration_timestamp);
|
||||
|
||||
// Now look at the JWT claim that was sent to the OAuth server.
|
||||
StringPiece grant_type, assertion;
|
||||
ASSERT_TRUE(strings::Scanner(post_body)
|
||||
.OneLiteral("grant_type=")
|
||||
.RestartCapture()
|
||||
.ScanEscapedUntil('&')
|
||||
.StopCapture()
|
||||
.OneLiteral("&assertion=")
|
||||
.GetResult(&assertion, &grant_type));
|
||||
EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer",
|
||||
grant_type.ToString());
|
||||
|
||||
int last_dot = assertion.ToString().find_last_of(".");
|
||||
string header_dot_claim = assertion.ToString().substr(0, last_dot);
|
||||
string signature_encoded = assertion.ToString().substr(last_dot + 1);
|
||||
|
||||
// Check that 'signature' signs 'header_dot_claim'.
|
||||
|
||||
// Read the serialized public key.
|
||||
std::ifstream public_key_stream(
|
||||
io::JoinPath(io::JoinPath(testing::TensorFlowSrcRoot(), kTestData),
|
||||
"service_account_public_key.txt"));
|
||||
string public_key_serialized(
|
||||
(std::istreambuf_iterator<char>(public_key_stream)),
|
||||
(std::istreambuf_iterator<char>()));
|
||||
|
||||
// Deserialize the public key.
|
||||
auto bio = BIO_new(BIO_s_mem());
|
||||
RSA* public_key = nullptr;
|
||||
EXPECT_EQ(public_key_serialized.size(),
|
||||
BIO_puts(bio, public_key_serialized.c_str()));
|
||||
public_key = PEM_read_bio_RSA_PUBKEY(bio, nullptr, nullptr, nullptr);
|
||||
EXPECT_TRUE(public_key) << "Could not load the public key from testdata.";
|
||||
|
||||
// Deserialize the signature.
|
||||
string signature;
|
||||
TF_EXPECT_OK(Base64Decode(signature_encoded, &signature));
|
||||
|
||||
// Actually cryptographically verify the signature.
|
||||
const auto md = EVP_sha256();
|
||||
auto md_ctx = EVP_MD_CTX_create();
|
||||
auto key = EVP_PKEY_new();
|
||||
EVP_PKEY_set1_RSA(key, public_key);
|
||||
ASSERT_EQ(1, EVP_DigestVerifyInit(md_ctx, nullptr, md, nullptr, key));
|
||||
ASSERT_EQ(1, EVP_DigestVerifyUpdate(md_ctx, header_dot_claim.c_str(),
|
||||
header_dot_claim.size()));
|
||||
ASSERT_EQ(
|
||||
1,
|
||||
EVP_DigestVerifyFinal(
|
||||
md_ctx, const_cast<unsigned char*>(
|
||||
reinterpret_cast<const unsigned char*>(signature.data())),
|
||||
signature.size()));
|
||||
EVP_MD_CTX_cleanup(md_ctx);
|
||||
|
||||
// Free all the crypto-related resources.
|
||||
EVP_PKEY_free(key);
|
||||
EVP_MD_CTX_destroy(md_ctx);
|
||||
RSA_free(public_key);
|
||||
BIO_free_all(bio);
|
||||
|
||||
// Now check the content of the header and the claim.
|
||||
int dot = header_dot_claim.find_last_of(".");
|
||||
string header_encoded = header_dot_claim.substr(0, dot);
|
||||
string claim_encoded = header_dot_claim.substr(dot + 1);
|
||||
|
||||
string header, claim;
|
||||
TF_EXPECT_OK(Base64Decode(header_encoded, &header));
|
||||
TF_EXPECT_OK(Base64Decode(claim_encoded, &claim));
|
||||
|
||||
Json::Value header_json, claim_json;
|
||||
EXPECT_TRUE(reader.parse(header, header_json));
|
||||
EXPECT_EQ("RS256", header_json.get("alg", Json::Value::null).asString());
|
||||
EXPECT_EQ("JWT", header_json.get("typ", Json::Value::null).asString());
|
||||
EXPECT_EQ("fake_key_id",
|
||||
header_json.get("kid", Json::Value::null).asString());
|
||||
|
||||
EXPECT_TRUE(reader.parse(claim, claim_json));
|
||||
EXPECT_EQ("fake-test-project.iam.gserviceaccount.com",
|
||||
claim_json.get("iss", Json::Value::null).asString());
|
||||
EXPECT_EQ("https://test-token-scope.com",
|
||||
claim_json.get("scope", Json::Value::null).asString());
|
||||
EXPECT_EQ("https://www.googleapis.com/oauth2/v3/token",
|
||||
claim_json.get("aud", Json::Value::null).asString());
|
||||
EXPECT_EQ(10000, claim_json.get("iat", Json::Value::null).asInt64());
|
||||
EXPECT_EQ(13600, claim_json.get("exp", Json::Value::null).asInt64());
|
||||
}
|
||||
} // namespace tensorflow
|
6
tensorflow/core/platform/cloud/testdata/application_default_credentials.json
vendored
Normal file
6
tensorflow/core/platform/cloud/testdata/application_default_credentials.json
vendored
Normal file
@ -0,0 +1,6 @@
|
||||
{
|
||||
"client_id": "fake-client-id.apps.googleusercontent.com",
|
||||
"client_secret": "fake-client-secret",
|
||||
"refresh_token": "fake-refresh-token",
|
||||
"type": "authorized_user"
|
||||
}
|
12
tensorflow/core/platform/cloud/testdata/service_account_credentials.json
vendored
Normal file
12
tensorflow/core/platform/cloud/testdata/service_account_credentials.json
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"type": "service_account",
|
||||
"project_id": "fake_project_id",
|
||||
"private_key_id": "fake_key_id",
|
||||
"private_key": "-----BEGIN RSA PRIVATE KEY-----\nMIIEpAIBAAKCAQEAwrEZE6PWQYAy68mWPMuC6KAD02Sb9Pv/FHWpGKe8MxxdDiz/\nspb2KIrWxxZolStHgDXAOoElbAv4GbRLJiivEl8k0gSP9YpIE56nSxfXxRIDH25N\nI3fhRIs5hSG+/p3lLV5NsdNrm1CYHnEbTY7Ow7gpyxl0n+6q+ngguZTOGtBIMqVS\n4KIJlzTlJgeqvLFbtLP6uFc4OuGL6UZ+s4I7zSJVPBRxrFA+mOhBEPz/QjANBHBd\nIEhgh5VlmX/oRUK+D3zR/MnRTYtD8skiZSFMUix1eWvKw/1wX0mieH1rUQbpIYdJ\nTgFhROKuAJWVU7c+T6JHZwm8DqXaVz6oCJPlzwIDAQABAoIBAGHQVAb4A0b5P5wS\ntXZp0KVK72EfZPNaP7dpvcDzVKxhDad3mCeDjLyltG5lpbl7+vpBBwjdpY15Hfbc\nC/1p5ztVrcwOGr2D8d5ZkTc7DV6nRAZghkTRj82+HPH0GF8XuPJoNKSo0aFAhoyU\nyuDWZK8UMXsmmN9ZK3GXNOnIBxyUs703ueIgNkH9zlT2x0wmEs4toZKiPVZhLUrc\nG1zLfuf1onhB5xq7u0sYZCiJrvaVvzNrKune1IrBM+FK/dc3k0vF9NEvwCYxWuTj\nGwO2wU3U945Scj9718pxhMMxZpsPZfMZHrYcdMvjpPaKFhJjxb16kT4gvSdm015j\nLgpM1xECgYEA35/KW4npUPoltBZ2Gi/YPmGVfpyXz6ToOw9ENawiGdNrOQG1Pw+v\nPBV0+yvcp1AvlL46lp87xQrl0dYHwwsQ7eRqpeyG6PCXRN7pJXP9Dac6Tq07lu2g\nriltHcuw8WYLv0gjrNr8IaCN04VS30d8MayXgHuvR3+NHkBdryuKFgsCgYEA3uD7\nmNukdNxJBQhgOO8lCbLXdEjgFFDBuh/9GvpqaeILP4MIwpWj9tA9Hjw5JlK3qpHL\nvLsJinKMmaswX43Hzf8OAAhTkSC/TfIJwZTGuBPoDH4UnMD+83SAk8DDgWTUvz/6\n1ilR4zm3kus6ZxTA1zp3P5UFD2etbv+cmGkjHc0CgYBkpw1z6j0j/5Oc3UdHPiW8\n3jtlg6IpCfalLpfq+JFYwnpObGBiA/NBvf6rVvC4NjVUY9MHHKDQbblHm2he98ok\n6Vy/VhjbG/9aNmMGQpCx5oUuCHb71fUuruK4OIhp/x5meFfmY6J8mEF95VKJwSk7\nSo3efM1GBzlDVoFUaOp8RQKBgQDWBQ0Ul7WwUef8YTKk+V+DlKy4CVLDr1iYNieC\nRHzy+BD9CALdd3xfgU9vPT1Tw5KCxEX0EVb0D1NcLLrixu7arNTwyw4UCnIpkwYz\nUX4RPWxSsq9wZxNrDLB7MVuLYRu6GuHvzPXJUJ8rAZ6vZYpYIthnwd1+EXzFXcct\nw6fo8QKBgQClY0EmhGIoDHNPjPOGzl2hmZCm5FKPx9i2SOOVYuSMdPT3qTYOp4/Q\nUp1oqkbd1ZWxMlbuRljpwbUHRcj85O5bkmWylINjpA1hFqxcxtj1r9xRmeO9Qcqa\n89jOblkbSoVDE5CFHD0Cv4bFw09z/l6Ih9DOW4AlB5UN+byEUPsIdw==\n-----END RSA PRIVATE KEY-----",
|
||||
"client_email": "fake-test-project.iam.gserviceaccount.com",
|
||||
"client_id": "fake_client_id",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://accounts.google.com/o/oauth2/token",
|
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/fake-test-project.iam.gserviceaccount.com"
|
||||
}
|
9
tensorflow/core/platform/cloud/testdata/service_account_public_key.txt
vendored
Normal file
9
tensorflow/core/platform/cloud/testdata/service_account_public_key.txt
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAwrEZE6PWQYAy68mWPMuC
|
||||
6KAD02Sb9Pv/FHWpGKe8MxxdDiz/spb2KIrWxxZolStHgDXAOoElbAv4GbRLJiiv
|
||||
El8k0gSP9YpIE56nSxfXxRIDH25NI3fhRIs5hSG+/p3lLV5NsdNrm1CYHnEbTY7O
|
||||
w7gpyxl0n+6q+ngguZTOGtBIMqVS4KIJlzTlJgeqvLFbtLP6uFc4OuGL6UZ+s4I7
|
||||
zSJVPBRxrFA+mOhBEPz/QjANBHBdIEhgh5VlmX/oRUK+D3zR/MnRTYtD8skiZSFM
|
||||
Uix1eWvKw/1wX0mieH1rUQbpIYdJTgFhROKuAJWVU7c+T6JHZwm8DqXaVz6oCJPl
|
||||
zwIDAQAB
|
||||
-----END PUBLIC KEY-----
|
@ -162,6 +162,10 @@ class Env {
|
||||
/// time. Only useful for computing deltas of time.
|
||||
virtual uint64 NowMicros() = 0;
|
||||
|
||||
/// \brief Returns the number of seconds since some fixed point in
|
||||
/// time. Only useful for computing deltas of time.
|
||||
virtual uint64 NowSeconds() { return NowMicros() / 1000000L; }
|
||||
|
||||
/// Sleeps/delays the thread for the prescribed number of micro-seconds.
|
||||
virtual void SleepForMicroseconds(int micros) = 0;
|
||||
|
||||
|
@ -120,3 +120,10 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
||||
name = "jsoncpp",
|
||||
actual = "@jsoncpp_git//:jsoncpp",
|
||||
)
|
||||
|
||||
native.new_git_repository(
|
||||
name = "boringssl_git",
|
||||
commit = "e72df93461c6d9d2b5698f10e16d3ab82f5adde3",
|
||||
remote = "https://boringssl.googlesource.com/boringssl",
|
||||
build_file = path_prefix + "boringssl.BUILD",
|
||||
)
|
||||
|
13
third_party/boringssl/BUILD
vendored
Normal file
13
third_party/boringssl/BUILD
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["restricted"]) # OpenSSL license, partly BSD-like
|
||||
|
||||
# See https://boringssl.googlesource.com/boringssl/+/master/INCORPORATING.md
|
||||
# on how to re-generate err_data.c.
|
||||
|
||||
filegroup(
|
||||
name = "err_data_c",
|
||||
srcs = [
|
||||
"err_data.c",
|
||||
],
|
||||
)
|
1236
third_party/boringssl/err_data.c
vendored
Normal file
1236
third_party/boringssl/err_data.c
vendored
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user