Update test_util.TensorFlowTestCase's assertAllEqual() and assertAllClose() methods to support RaggedTensors.
PiperOrigin-RevId: 254878390
This commit is contained in:
parent
b6b7d99893
commit
b972f7334e
@ -1,266 +1,266 @@
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/contrib/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/contrib/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/README
|
||||
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
|
||||
tensorflow/contrib/mpi/BUILD
|
||||
tensorflow/stream_executor/build_defs.bzl
|
||||
tensorflow/python/autograph/core/config.py
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/toolchains/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_025/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/py/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/gcc7-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/centos6/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/cc_toolchain_config.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/LICENSE
|
||||
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
tensorflow/third_party/gpus/crosstool/BUILD
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/mkl_dnn/LICENSE
|
||||
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/fft2d/fft2d.h
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/find_cuda_config.py
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/kafka/BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/six.BUILD
|
||||
tensorflow/third_party/zlib.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/pybind11.BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/__init__.py
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/wrapt.BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/fft2d/fft2d.h
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/archive.patch
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/tensorrt/tensorrt/include/tensorrt_config.h.tpl
|
||||
tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/kafka/BUILD
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/mkl_dnn/LICENSE
|
||||
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/pybind11.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/wrapt.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/compat_template.__init__.py
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/contrib/tpu/profiler/pip_package/README
|
||||
tensorflow/contrib/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/contrib/mpi/BUILD
|
||||
tensorflow/python/autograph/core/config.py
|
||||
tensorflow/virtual_root_template_v2.__init__.py
|
||||
tensorflow/__init__.py
|
||||
tensorflow/stream_executor/build_defs.bzl
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/six.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/__init__.py
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/zlib.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/api_template_v1.__init__.py
|
||||
tensorflow/compat_template_v1.__init__.py
|
||||
tensorflow/compat_template.__init__.py
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/__init__.py
|
||||
tensorflow/virtual_root_template_v2.__init__.py
|
||||
tensorflow/virtual_root_template_v1.__init__.py
|
@ -276,9 +276,12 @@ tf_py_test(
|
||||
additional_deps = [
|
||||
":test_base",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
@ -709,12 +712,12 @@ py_library(
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/ops/ragged",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -258,7 +258,7 @@ class FromTensorSlicesTest(test_base.DatasetTestBase):
|
||||
if sparse_tensor.is_sparse(component):
|
||||
self.assertSparseValuesEqual(component, result_component)
|
||||
elif ragged_tensor.is_ragged(component):
|
||||
self.assertRaggedEqual(component, result_component)
|
||||
self.assertAllEqual(component, result_component)
|
||||
else:
|
||||
self.assertAllEqual(component, result_component)
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
|
@ -30,11 +30,10 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class DatasetTestBase(ragged_test_util.RaggedTensorTestCase, test.TestCase):
|
||||
class DatasetTestBase(test.TestCase):
|
||||
"""Base class for dataset tests."""
|
||||
|
||||
@classmethod
|
||||
@ -108,7 +107,7 @@ class DatasetTestBase(ragged_test_util.RaggedTensorTestCase, test.TestCase):
|
||||
if sparse_tensor.is_sparse(result_value):
|
||||
self.assertSparseValuesEqual(result_value, expected_value)
|
||||
elif ragged_tensor.is_ragged(result_value):
|
||||
self.assertRaggedEqual(result_value, expected_value)
|
||||
self.assertAllEqual(result_value, expected_value)
|
||||
else:
|
||||
self.assertAllEqual(
|
||||
result_value,
|
||||
@ -209,7 +208,7 @@ class DatasetTestBase(ragged_test_util.RaggedTensorTestCase, test.TestCase):
|
||||
if sparse_tensor.is_sparse(op1[i]):
|
||||
self.assertSparseValuesEqual(op1[i], op2[i])
|
||||
elif ragged_tensor.is_ragged(op1[i]):
|
||||
self.assertRaggedEqual(op1[i], op2[i])
|
||||
self.assertAllEqual(op1[i], op2[i])
|
||||
elif flattened_types[i] == dtypes.string:
|
||||
self.assertAllEqual(op1[i], op2[i])
|
||||
else:
|
||||
|
@ -96,14 +96,21 @@ py_test(
|
||||
":structure",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:tensor_array_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:type_spec",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/data/kernel_tests:test_base",
|
||||
"//tensorflow/python/ops/ragged",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor_value",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
@ -37,7 +37,6 @@ from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -47,7 +46,7 @@ from tensorflow.python.platform import test
|
||||
#
|
||||
# TODO(jsimsa): Add tests for OptionalStructure and DatasetStructure.
|
||||
class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
ragged_test_util.RaggedTensorTestCase):
|
||||
test_util.TensorFlowTestCase):
|
||||
|
||||
# pylint: disable=g-long-lambda,protected-access
|
||||
@parameterized.named_parameters(
|
||||
@ -348,7 +347,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
elif isinstance(
|
||||
b,
|
||||
(ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
|
||||
self.assertRaggedEqual(b, a)
|
||||
self.assertAllEqual(b, a)
|
||||
else:
|
||||
self.assertAllEqual(b, a)
|
||||
|
||||
@ -707,7 +706,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
if sparse_tensor.is_sparse(expected):
|
||||
self.assertSparseValuesEqual(expected, actual)
|
||||
elif ragged_tensor.is_ragged(expected):
|
||||
self.assertRaggedEqual(expected, actual)
|
||||
self.assertAllEqual(expected, actual)
|
||||
else:
|
||||
self.assertAllEqual(expected, actual)
|
||||
|
||||
|
@ -1900,7 +1900,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
tensor.dense_shape.numpy())
|
||||
elif ragged_tensor.is_ragged(tensor):
|
||||
return ragged_tensor_value.RaggedTensorValue(
|
||||
tensor.values.numpy(), tensor.row_splits.numpy())
|
||||
self._eval_tensor(tensor.values),
|
||||
self._eval_tensor(tensor.row_splits))
|
||||
elif isinstance(tensor, ops.IndexedSlices):
|
||||
return ops.IndexedSlicesValue(
|
||||
values=tensor.values.numpy(),
|
||||
@ -2363,6 +2364,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
|
||||
`[p] = [1]['d']`, then `a[p] = (6, 7)`.
|
||||
"""
|
||||
if ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b):
|
||||
return self._assertRaggedClose(a, b, rtol, atol, msg)
|
||||
self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
|
||||
|
||||
@py_func_if_in_function
|
||||
@ -2441,6 +2444,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
b: the actual numpy ndarray or anything can be converted to one.
|
||||
msg: Optional message to report on failure.
|
||||
"""
|
||||
if (ragged_tensor.is_ragged(a) or ragged_tensor.is_ragged(b)):
|
||||
return self._assertRaggedEqual(a, b, msg)
|
||||
msg = msg if msg else ""
|
||||
a = self._GetNdArray(a)
|
||||
b = self._GetNdArray(b)
|
||||
@ -2730,6 +2735,51 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
device1, device2,
|
||||
"Devices %s and %s are not equal. %s" % (device1, device2, msg))
|
||||
|
||||
def _GetPyList(self, a):
|
||||
"""Converts `a` to a nested python list."""
|
||||
if isinstance(a, ragged_tensor.RaggedTensor):
|
||||
return self.evaluate(a).to_list()
|
||||
elif isinstance(a, ops.Tensor):
|
||||
a = self.evaluate(a)
|
||||
return a.tolist() if isinstance(a, np.ndarray) else a
|
||||
elif isinstance(a, np.ndarray):
|
||||
return a.tolist()
|
||||
elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
|
||||
return a.to_list()
|
||||
else:
|
||||
return np.array(a).tolist()
|
||||
|
||||
def _assertRaggedEqual(self, a, b, msg):
|
||||
"""Asserts that two ragged tensors are equal."""
|
||||
a_list = self._GetPyList(a)
|
||||
b_list = self._GetPyList(b)
|
||||
self.assertEqual(a_list, b_list, msg)
|
||||
|
||||
if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
|
||||
a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
|
||||
b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
|
||||
self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
|
||||
|
||||
def _assertRaggedClose(self, a, b, rtol, atol, msg=None):
|
||||
a_list = self._GetPyList(a)
|
||||
b_list = self._GetPyList(b)
|
||||
self._assertListCloseRecursive(a_list, b_list, rtol, atol, msg)
|
||||
|
||||
if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
|
||||
a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
|
||||
b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
|
||||
self.assertEqual(a_ragged_rank, b_ragged_rank, msg)
|
||||
|
||||
def _assertListCloseRecursive(self, a, b, rtol, atol, msg, path="value"):
|
||||
self.assertEqual(type(a), type(b))
|
||||
if isinstance(a, (list, tuple)):
|
||||
self.assertLen(a, len(b), "Length differs for %s" % path)
|
||||
for i in range(len(a)):
|
||||
self._assertListCloseRecursive(a[i], b[i], rtol, atol, msg,
|
||||
"%s[%s]" % (path, i))
|
||||
else:
|
||||
self._assertAllCloseRecursive(a, b, rtol, atol, path, msg)
|
||||
|
||||
# Fix Python 3 compatibility issues
|
||||
if six.PY3:
|
||||
# pylint: disable=invalid-name
|
||||
|
@ -1158,7 +1158,6 @@ tf_py_test(
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -1679,11 +1678,13 @@ tf_py_test(
|
||||
additional_deps = [
|
||||
":keras",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
],
|
||||
)
|
||||
|
@ -40,7 +40,6 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -155,8 +154,7 @@ def get_model_from_layers_with_input(layers,
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class CompositeTensorInternalTest(keras_parameterized.TestCase,
|
||||
ragged_test_util.RaggedTensorTestCase):
|
||||
class CompositeTensorInternalTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_internal_ragged_tensors(self):
|
||||
# Create a model that accepts an input, converts it to Ragged, and
|
||||
@ -207,8 +205,7 @@ class CompositeTensorInternalTest(keras_parameterized.TestCase,
|
||||
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class CompositeTensorOutputTest(keras_parameterized.TestCase,
|
||||
ragged_test_util.RaggedTensorTestCase):
|
||||
class CompositeTensorOutputTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_ragged_tensor_outputs(self):
|
||||
# Create a model that accepts an input, converts it to Ragged, and
|
||||
@ -221,7 +218,7 @@ class CompositeTensorOutputTest(keras_parameterized.TestCase,
|
||||
output = model.predict(input_data)
|
||||
|
||||
expected_values = [[1], [2, 3]]
|
||||
self.assertRaggedEqual(expected_values, output)
|
||||
self.assertAllEqual(expected_values, output)
|
||||
|
||||
def test_ragged_tensor_rebatched_outputs(self):
|
||||
# Create a model that accepts an input, converts it to Ragged, and
|
||||
@ -234,7 +231,7 @@ class CompositeTensorOutputTest(keras_parameterized.TestCase,
|
||||
output = model.predict(input_data, batch_size=2)
|
||||
|
||||
expected_values = [[1], [2, 3], [4], [5, 6]]
|
||||
self.assertRaggedEqual(expected_values, output)
|
||||
self.assertAllEqual(expected_values, output)
|
||||
|
||||
def test_sparse_tensor_outputs(self):
|
||||
# Create a model that accepts an input, converts it to Ragged, and
|
||||
@ -315,8 +312,7 @@ def prepare_inputs(data, use_dict, use_dataset, action, input_name):
|
||||
use_dict=[True, False],
|
||||
use_dataset=[True, False],
|
||||
action=["predict", "evaluate", "fit"]))
|
||||
class SparseTensorInputTest(keras_parameterized.TestCase,
|
||||
ragged_test_util.RaggedTensorTestCase):
|
||||
class SparseTensorInputTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_sparse_tensors(self, use_dict, use_dataset, action):
|
||||
data = [(sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
||||
@ -357,7 +353,7 @@ class SparseTensorInputTest(keras_parameterized.TestCase,
|
||||
@keras_parameterized.run_with_all_model_types
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class ScipySparseTensorInputTest(keras_parameterized.TestCase,
|
||||
ragged_test_util.RaggedTensorTestCase):
|
||||
test_util.TensorFlowTestCase):
|
||||
|
||||
def test_sparse_scipy_predict_inputs_via_input_layer_args(self):
|
||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
||||
@ -473,7 +469,7 @@ class ScipySparseTensorInputTest(keras_parameterized.TestCase,
|
||||
use_dataset=[True, False],
|
||||
action=["predict", "evaluate", "fit"]))
|
||||
class RaggedTensorInputTest(keras_parameterized.TestCase,
|
||||
ragged_test_util.RaggedTensorTestCase):
|
||||
test_util.TensorFlowTestCase):
|
||||
|
||||
def test_ragged_input(self, use_dict, use_dataset, action):
|
||||
data = [(ragged_factory_ops.constant([[[1]], [[2, 3]]]),
|
||||
@ -510,7 +506,7 @@ class RaggedTensorInputTest(keras_parameterized.TestCase,
|
||||
*test_util.generate_combinations_with_testcase_name(
|
||||
use_dict=[True, False], use_dataset=[True, False]))
|
||||
class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
||||
ragged_test_util.RaggedTensorTestCase):
|
||||
test_util.TensorFlowTestCase):
|
||||
|
||||
def test_ragged_tensor_input_with_one_none_dimension(self, use_dict,
|
||||
use_dataset):
|
||||
@ -596,8 +592,7 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
||||
# subclassed models, so we run a separate parameterized test for them.
|
||||
@keras_parameterized.run_with_all_model_types(exclude_models=["subclass"])
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_eager=True)
|
||||
class SparseTensorInputValidationTest(keras_parameterized.TestCase,
|
||||
ragged_test_util.RaggedTensorTestCase):
|
||||
class SparseTensorInputValidationTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_sparse_scipy_input_checks_shape(self):
|
||||
model_input = input_layer.Input(shape=(3,), sparse=True, dtype=dtypes.int32)
|
||||
@ -647,8 +642,7 @@ class SparseTensorInputValidationTest(keras_parameterized.TestCase,
|
||||
@keras_parameterized.run_with_all_model_types(
|
||||
exclude_models=["functional"])
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class UndefinedCompositeTensorInputsTest(keras_parameterized.TestCase,
|
||||
ragged_test_util.RaggedTensorTestCase):
|
||||
class UndefinedCompositeTensorInputsTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_subclass_implicit_sparse_inputs_fails(self):
|
||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
||||
|
@ -25,13 +25,11 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.utils import metrics_utils
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedSizeOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
parameterized.TestCase):
|
||||
class RaggedSizeOpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
|
@ -1077,14 +1077,17 @@ tf_py_test(
|
||||
size = "small",
|
||||
srcs = ["string_split_op_test.py"],
|
||||
additional_deps = [
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_string_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1093,14 +1096,17 @@ tf_py_test(
|
||||
size = "small",
|
||||
srcs = ["string_bytes_split_op_test.py"],
|
||||
additional_deps = [
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_string_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python/ops/ragged",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1327,11 +1333,14 @@ tf_py_test(
|
||||
additional_deps = [
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor_value",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_string_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1353,13 +1362,14 @@ tf_py_test(
|
||||
srcs = ["unicode_decode_op_test.py"],
|
||||
additional_deps = [
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/ops/ragged:ragged_factory_ops",
|
||||
"//tensorflow/python/ops/ragged:ragged_string_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_for_generated_wrappers",
|
||||
"//tensorflow/python/ops/ragged:ragged",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
|
@ -22,13 +22,13 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
class StringsToBytesOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class StringsToBytesOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -62,7 +62,7 @@ class StringsToBytesOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
def testStringToBytes(self, source, expected):
|
||||
expected = ragged_factory_ops.constant_value(expected, dtype=object)
|
||||
result = ragged_string_ops.string_bytes_split(source)
|
||||
self.assertRaggedEqual(expected, result)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -29,7 +29,6 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
@ -171,8 +170,7 @@ class StringSplitOpTest(test.TestCase):
|
||||
self.assertAllEqual(shape, [3, 1])
|
||||
|
||||
|
||||
class StringSplitV2OpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
parameterized.TestCase):
|
||||
class StringSplitV2OpTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters([
|
||||
{"testcase_name": "Simple",
|
||||
@ -278,11 +276,11 @@ class StringSplitV2OpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
actual_ragged_v2 = ragged_string_ops.string_split_v2(input, **kwargs)
|
||||
actual_ragged_v2_input_kwarg = ragged_string_ops.string_split_v2(
|
||||
input=input, **kwargs)
|
||||
self.assertRaggedEqual(expected_ragged, actual_ragged_v1)
|
||||
self.assertRaggedEqual(expected_ragged, actual_ragged_v1_input_kwarg)
|
||||
self.assertRaggedEqual(expected_ragged, actual_ragged_v1_source_kwarg)
|
||||
self.assertRaggedEqual(expected_ragged, actual_ragged_v2)
|
||||
self.assertRaggedEqual(expected_ragged, actual_ragged_v2_input_kwarg)
|
||||
self.assertAllEqual(expected_ragged, actual_ragged_v1)
|
||||
self.assertAllEqual(expected_ragged, actual_ragged_v1_input_kwarg)
|
||||
self.assertAllEqual(expected_ragged, actual_ragged_v1_source_kwarg)
|
||||
self.assertAllEqual(expected_ragged, actual_ragged_v2)
|
||||
self.assertAllEqual(expected_ragged, actual_ragged_v2_input_kwarg)
|
||||
|
||||
# Check that the internal version (which returns a SparseTensor) works
|
||||
# correctly. Note: the internal version oly supports vector inputs.
|
||||
|
@ -31,7 +31,6 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -91,7 +90,7 @@ def _make_sparse_tensor(indices, values, dense_shape, dtype=np.int32):
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class UnicodeDecodeTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def testScalarDecode(self):
|
||||
@ -110,15 +109,15 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
chars = ragged_string_ops.unicode_decode(text, "utf-8")
|
||||
expected_chars = [[ord(c) for c in u"仅今年前"],
|
||||
[ord(c) for c in u"hello"]]
|
||||
self.assertRaggedEqual(chars, expected_chars)
|
||||
self.assertAllEqual(chars, expected_chars)
|
||||
|
||||
def testVectorDecodeWithOffset(self):
|
||||
text = constant_op.constant([u"仅今年前".encode("utf-8"), b"hello"])
|
||||
chars, starts = ragged_string_ops.unicode_decode_with_offsets(text, "utf-8")
|
||||
expected_chars = [[ord(c) for c in u"仅今年前"],
|
||||
[ord(c) for c in u"hello"]]
|
||||
self.assertRaggedEqual(chars, expected_chars)
|
||||
self.assertRaggedEqual(starts, [[0, 3, 6, 9], [0, 1, 2, 3, 4]])
|
||||
self.assertAllEqual(chars, expected_chars)
|
||||
self.assertAllEqual(starts, [[0, 3, 6, 9], [0, 1, 2, 3, 4]])
|
||||
|
||||
@parameterized.parameters([
|
||||
{"texts": u"仅今年前"},
|
||||
@ -134,7 +133,7 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
_nested_encode(texts, "UTF-8"), ragged_rank=ragged_rank, dtype=bytes)
|
||||
result = ragged_string_ops.unicode_decode(input_tensor, "UTF-8")
|
||||
expected = _nested_codepoints(texts)
|
||||
self.assertRaggedEqual(expected, result)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
@parameterized.parameters([
|
||||
{"texts": u"仅今年前"},
|
||||
@ -152,19 +151,19 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
input_tensor, "UTF-8")
|
||||
expected_codepoints = _nested_codepoints(texts)
|
||||
expected_offsets = _nested_offsets(texts, "UTF-8")
|
||||
self.assertRaggedEqual(expected_codepoints, result[0])
|
||||
self.assertRaggedEqual(expected_offsets, result[1])
|
||||
self.assertAllEqual(expected_codepoints, result[0])
|
||||
self.assertAllEqual(expected_offsets, result[1])
|
||||
|
||||
def testDocstringExamples(self):
|
||||
texts = [s.encode("utf8") for s in [u"G\xf6\xf6dnight", u"\U0001f60a"]]
|
||||
codepoints1 = ragged_string_ops.unicode_decode(texts, "UTF-8")
|
||||
codepoints2, offsets = ragged_string_ops.unicode_decode_with_offsets(
|
||||
texts, "UTF-8")
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
codepoints1, [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
codepoints2, [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]])
|
||||
self.assertRaggedEqual(offsets, [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]])
|
||||
self.assertAllEqual(offsets, [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]])
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
@ -263,7 +262,7 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
]) # pyformat: disable
|
||||
def testErrorModes(self, expected=None, **args):
|
||||
result = ragged_string_ops.unicode_decode(**args)
|
||||
self.assertRaggedEqual(expected, result)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
@ -314,8 +313,8 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
expected_offsets=None,
|
||||
**args):
|
||||
result = ragged_string_ops.unicode_decode_with_offsets(**args)
|
||||
self.assertRaggedEqual(result[0], expected)
|
||||
self.assertRaggedEqual(result[1], expected_offsets)
|
||||
self.assertAllEqual(result[0], expected)
|
||||
self.assertAllEqual(result[1], expected_offsets)
|
||||
|
||||
@parameterized.parameters(
|
||||
("UTF-8", [u"こんにちは", u"你好", u"Hello"]),
|
||||
@ -329,7 +328,7 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
expected = _nested_codepoints(texts)
|
||||
input_tensor = constant_op.constant(_nested_encode(texts, encoding))
|
||||
result = ragged_string_ops.unicode_decode(input_tensor, encoding)
|
||||
self.assertRaggedEqual(expected, result)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
@parameterized.parameters(
|
||||
("UTF-8", [u"こんにちは", u"你好", u"Hello"]),
|
||||
@ -345,8 +344,8 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
input_tensor = constant_op.constant(_nested_encode(texts, encoding))
|
||||
result = ragged_string_ops.unicode_decode_with_offsets(
|
||||
input_tensor, encoding)
|
||||
self.assertRaggedEqual(expected_codepoints, result[0])
|
||||
self.assertRaggedEqual(expected_offsets, result[1])
|
||||
self.assertAllEqual(expected_codepoints, result[0])
|
||||
self.assertAllEqual(expected_offsets, result[1])
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(input=[b"\xFEED"],
|
||||
@ -423,7 +422,7 @@ class UnicodeDecodeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class UnicodeSplitTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def testScalarSplit(self):
|
||||
@ -442,15 +441,15 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
|
||||
chars = ragged_string_ops.unicode_split(text, "UTF-8")
|
||||
expected_chars = [[c.encode("UTF-8") for c in u"仅今年前"],
|
||||
[c.encode("UTF-8") for c in u"hello"]]
|
||||
self.assertRaggedEqual(chars, expected_chars)
|
||||
self.assertAllEqual(chars, expected_chars)
|
||||
|
||||
def testVectorSplitWithOffset(self):
|
||||
text = constant_op.constant([u"仅今年前".encode("UTF-8"), b"hello"])
|
||||
chars, starts = ragged_string_ops.unicode_split_with_offsets(text, "UTF-8")
|
||||
expected_chars = [[c.encode("UTF-8") for c in u"仅今年前"],
|
||||
[c.encode("UTF-8") for c in u"hello"]]
|
||||
self.assertRaggedEqual(chars, expected_chars)
|
||||
self.assertRaggedEqual(starts, [[0, 3, 6, 9], [0, 1, 2, 3, 4]])
|
||||
self.assertAllEqual(chars, expected_chars)
|
||||
self.assertAllEqual(starts, [[0, 3, 6, 9], [0, 1, 2, 3, 4]])
|
||||
|
||||
@parameterized.parameters([
|
||||
{"texts": u"仅今年前"},
|
||||
@ -466,7 +465,7 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
|
||||
_nested_encode(texts, "UTF-8"), ragged_rank=ragged_rank, dtype=bytes)
|
||||
result = ragged_string_ops.unicode_split(input_tensor, "UTF-8")
|
||||
expected = _nested_splitchars(texts, "UTF-8")
|
||||
self.assertRaggedEqual(expected, result)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
@parameterized.parameters([
|
||||
{"texts": u"仅今年前"},
|
||||
@ -483,23 +482,23 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
|
||||
result = ragged_string_ops.unicode_split_with_offsets(input_tensor, "UTF-8")
|
||||
expected_codepoints = _nested_splitchars(texts, "UTF-8")
|
||||
expected_offsets = _nested_offsets(texts, "UTF-8")
|
||||
self.assertRaggedEqual(expected_codepoints, result[0])
|
||||
self.assertRaggedEqual(expected_offsets, result[1])
|
||||
self.assertAllEqual(expected_codepoints, result[0])
|
||||
self.assertAllEqual(expected_offsets, result[1])
|
||||
|
||||
def testDocstringExamples(self):
|
||||
texts = [s.encode("utf8") for s in [u"G\xf6\xf6dnight", u"\U0001f60a"]]
|
||||
codepoints1 = ragged_string_ops.unicode_split(texts, "UTF-8")
|
||||
codepoints2, offsets = ragged_string_ops.unicode_split_with_offsets(
|
||||
texts, "UTF-8")
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
codepoints1,
|
||||
[[b"G", b"\xc3\xb6", b"\xc3\xb6", b"d", b"n", b"i", b"g", b"h", b"t"],
|
||||
[b"\xf0\x9f\x98\x8a"]])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
codepoints2,
|
||||
[[b"G", b"\xc3\xb6", b"\xc3\xb6", b"d", b"n", b"i", b"g", b"h", b"t"],
|
||||
[b"\xf0\x9f\x98\x8a"]])
|
||||
self.assertRaggedEqual(offsets, [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]])
|
||||
self.assertAllEqual(offsets, [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]])
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
@ -604,7 +603,7 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
|
||||
]) # pyformat: disable
|
||||
def testErrorModes(self, expected=None, **args):
|
||||
result = ragged_string_ops.unicode_split(**args)
|
||||
self.assertRaggedEqual(expected, result)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
@ -644,8 +643,8 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
|
||||
expected_offsets=None,
|
||||
**args):
|
||||
result = ragged_string_ops.unicode_split_with_offsets(**args)
|
||||
self.assertRaggedEqual(expected, result[0])
|
||||
self.assertRaggedEqual(expected_offsets, result[1])
|
||||
self.assertAllEqual(expected, result[0])
|
||||
self.assertAllEqual(expected_offsets, result[1])
|
||||
|
||||
@parameterized.parameters(
|
||||
("UTF-8", [u"こんにちは", u"你好", u"Hello"]),
|
||||
@ -656,7 +655,7 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
|
||||
expected = _nested_splitchars(texts, encoding)
|
||||
input_tensor = constant_op.constant(_nested_encode(texts, encoding))
|
||||
result = ragged_string_ops.unicode_split(input_tensor, encoding)
|
||||
self.assertRaggedEqual(expected, result)
|
||||
self.assertAllEqual(expected, result)
|
||||
|
||||
@parameterized.parameters(
|
||||
("UTF-8", [u"こんにちは", u"你好", u"Hello"]),
|
||||
@ -669,8 +668,8 @@ class UnicodeSplitTest(ragged_test_util.RaggedTensorTestCase,
|
||||
input_tensor = constant_op.constant(_nested_encode(texts, encoding))
|
||||
result = ragged_string_ops.unicode_split_with_offsets(
|
||||
input_tensor, encoding)
|
||||
self.assertRaggedEqual(expected_codepoints, result[0])
|
||||
self.assertRaggedEqual(expected_offsets, result[1])
|
||||
self.assertAllEqual(expected_codepoints, result[0])
|
||||
self.assertAllEqual(expected_offsets, result[1])
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(input=[b"\xFEED"],
|
||||
|
@ -33,7 +33,7 @@ from tensorflow.python.platform import test
|
||||
|
||||
class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def assertRaggedEqual(self, rt, expected):
|
||||
def assertAllEqual(self, rt, expected):
|
||||
with self.cached_session() as sess:
|
||||
value = sess.run(rt)
|
||||
if isinstance(value, np.ndarray):
|
||||
@ -76,10 +76,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
expected_value = u"Heo".encode(encoding)
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding,
|
||||
"ignore")
|
||||
with self.cached_session() as session:
|
||||
result = session.run(unicode_encode_op)
|
||||
self.assertIsInstance(result, bytes)
|
||||
self.assertAllEqual(result, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@ -88,20 +85,20 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
expected_value = u"He\U0000fffd\U0000fffdo".encode(encoding)
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding,
|
||||
"replace")
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
# Test custom replacement character
|
||||
test_value = np.array([72, 101, 2147483647, -1, 111], np.int32)
|
||||
expected_value = u"Heooo".encode(encoding)
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding,
|
||||
"replace", 111)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
# Verify "replace" is default
|
||||
test_value = np.array([72, 101, 2147483647, -1, 111], np.int32)
|
||||
expected_value = u"He\U0000fffd\U0000fffdo".encode(encoding)
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
# Replacement_char must be within range
|
||||
test_value = np.array([72, 101, 2147483647, -1, 111], np.int32)
|
||||
@ -118,23 +115,23 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
test_value = np.array([72, 101, 108, 108, 111], np.int32)
|
||||
expected_value = u"Hello".encode(encoding)
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
test_value = np.array([72, 101, 195, 195, 128516], np.int32)
|
||||
expected_value = u"He\xc3\xc3\U0001f604".encode(encoding)
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
# Single character string
|
||||
test_value = np.array([72], np.int32)
|
||||
expected_value = u"H".encode(encoding)
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
test_value = np.array([128516], np.int32)
|
||||
expected_value = u"\U0001f604".encode(encoding)
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@ -158,7 +155,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
[u"fixed".encode(encoding), u"words".encode(encoding)],
|
||||
[u"Hyper".encode(encoding), u"cube.".encode(encoding)]]
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@ -174,7 +171,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
[[u"Hyper".encode(encoding)],
|
||||
[u"cube.".encode(encoding)]]]
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
# -- Ragged Tensor tests -- #
|
||||
|
||||
@ -187,7 +184,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
u"H\xc3llo".encode(encoding), u"W\U0001f604rld.".encode(encoding)
|
||||
]
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@ -204,7 +201,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
u"cube.".encode(encoding)
|
||||
]]
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@ -219,7 +216,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
u"w\xc3rry, be".encode(encoding)
|
||||
], [u"\U0001f604".encode(encoding), u"".encode(encoding)]]
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@ -230,7 +227,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
expected_value = [[u"Hello".encode(encoding), u"World.".encode(encoding)],
|
||||
[], [u"\U0001f604".encode(encoding)]]
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@ -241,7 +238,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
expected_value = [[[u"Hello".encode(encoding), u"World".encode(encoding)]],
|
||||
[[u"".encode(encoding)], [u"Hype".encode(encoding)]]]
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
@parameterized.parameters("UTF-8", "UTF-16-BE", "UTF-32-BE")
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@ -264,7 +261,7 @@ class UnicodeEncodeOpTest(test.TestCase, parameterized.TestCase):
|
||||
[u"Hyper".encode(encoding),
|
||||
u"cube.".encode(encoding)]]]]
|
||||
unicode_encode_op = ragged_string_ops.unicode_encode(test_value, encoding)
|
||||
self.assertRaggedEqual(unicode_encode_op, expected_value)
|
||||
self.assertAllEqual(unicode_encode_op, expected_value)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -53,21 +53,17 @@ py_library(
|
||||
srcs = ["ragged_array_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_conversion_ops",
|
||||
":ragged_functional_ops",
|
||||
":ragged_math_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_util",
|
||||
":segment_id_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:ragged_array_ops_gen",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
@ -76,17 +72,11 @@ py_library(
|
||||
srcs = ["ragged_batch_gather_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_concat_ops",
|
||||
":ragged_conversion_ops",
|
||||
":ragged_gather_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_shape",
|
||||
":ragged_util",
|
||||
":ragged_where_op",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
@ -100,8 +90,6 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_batch_gather_ops",
|
||||
":ragged_concat_ops",
|
||||
":ragged_dispatch",
|
||||
":ragged_operators",
|
||||
":ragged_tensor",
|
||||
@ -109,7 +97,7 @@ py_library(
|
||||
":ragged_where_op",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
@ -121,7 +109,6 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_conversion_ops",
|
||||
":ragged_gather_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_util",
|
||||
@ -138,19 +125,8 @@ py_library(
|
||||
srcs = ["ragged_conversion_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:ragged_conversion_ops_gen",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
@ -158,15 +134,14 @@ py_library(
|
||||
name = "ragged_factory_ops",
|
||||
srcs = ["ragged_factory_ops.py"],
|
||||
deps = [
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_value",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor_value",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
@ -176,10 +151,12 @@ py_library(
|
||||
srcs = ["ragged_functional_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_config",
|
||||
":ragged_tensor",
|
||||
":ragged_util",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
@ -190,7 +167,6 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_conversion_ops",
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -206,15 +182,14 @@ py_library(
|
||||
srcs = ["ragged_getitem.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_gather_ops",
|
||||
":ragged_math_ops",
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
@ -223,7 +198,6 @@ py_library(
|
||||
srcs = ["ragged_math_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_functional_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_util",
|
||||
@ -236,6 +210,7 @@ py_library(
|
||||
"//tensorflow/python:ragged_math_ops_gen",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -257,14 +232,10 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_conversion_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
@ -276,10 +247,13 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_shape",
|
||||
":ragged_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:ops",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
)
|
||||
|
||||
@ -287,9 +261,6 @@ py_library(
|
||||
name = "ragged_config",
|
||||
srcs = ["ragged_config.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python:dtypes",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -301,11 +272,23 @@ py_library(
|
||||
":ragged_tensor_value",
|
||||
":ragged_util",
|
||||
":segment_id_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:composite_tensor",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:ragged_conversion_ops_gen",
|
||||
"//tensorflow/python:session",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:type_spec",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
@ -315,8 +298,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_conversion_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_config",
|
||||
":ragged_tensor",
|
||||
":ragged_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -325,6 +307,7 @@ py_library(
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_util",
|
||||
],
|
||||
)
|
||||
@ -363,7 +346,6 @@ py_library(
|
||||
":ragged_gather_ops",
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
],
|
||||
@ -374,7 +356,6 @@ py_library(
|
||||
srcs = ["segment_id_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_config",
|
||||
":ragged_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -391,6 +372,7 @@ py_library(
|
||||
srcs = ["ragged_map_ops.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_config",
|
||||
":ragged_tensor",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
@ -415,14 +397,18 @@ py_library(
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_batch_gather_ops",
|
||||
":ragged_concat_ops",
|
||||
":ragged_gather_ops",
|
||||
":ragged_math_ops",
|
||||
":ragged_squeeze_op",
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_shape",
|
||||
":ragged_util",
|
||||
":ragged_where_op",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:bitwise_ops",
|
||||
"//tensorflow/python:clip_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
@ -438,19 +424,6 @@ py_library(
|
||||
# RaggedTensor Tests
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
py_library(
|
||||
name = "ragged_test_util",
|
||||
srcs = ["ragged_test_util.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_value",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ragged_tensor_test",
|
||||
size = "medium",
|
||||
@ -467,7 +440,6 @@ py_test(
|
||||
":ragged_math_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_value",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -475,6 +447,7 @@ py_test(
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
@ -489,8 +462,8 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
@ -503,7 +476,6 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_math_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -518,7 +490,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -532,7 +503,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -546,10 +516,8 @@ py_test(
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_gather_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -567,13 +535,12 @@ py_test(
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_batch_gather_ops",
|
||||
":ragged_batch_gather_with_default_op",
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -591,7 +558,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_gather_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
@ -609,7 +575,6 @@ py_test(
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_test_util",
|
||||
":segment_id_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -623,7 +588,6 @@ py_test(
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_test_util",
|
||||
":segment_id_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -637,15 +601,11 @@ py_test(
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_string_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
@ -664,7 +624,6 @@ py_test(
|
||||
":ragged_factory_ops",
|
||||
":ragged_functional_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
@ -683,7 +642,6 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
@ -701,11 +659,10 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
@ -719,7 +676,6 @@ py_test(
|
||||
":ragged_factory_ops",
|
||||
":ragged_math_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:errors",
|
||||
@ -737,7 +693,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_math_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -758,7 +713,6 @@ py_test(
|
||||
":ragged_factory_ops",
|
||||
":ragged_functional_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -778,10 +732,10 @@ py_test(
|
||||
":ragged",
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
@ -796,9 +750,7 @@ py_test(
|
||||
],
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_value",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//third_party/py/numpy",
|
||||
@ -814,7 +766,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
@ -832,7 +783,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
@ -851,7 +801,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_concat_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -871,7 +820,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_concat_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -887,8 +835,8 @@ py_test(
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
@ -901,7 +849,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -917,7 +864,6 @@ py_test(
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_test_util",
|
||||
":ragged_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
@ -936,7 +882,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
@ -950,7 +895,6 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
":ragged_where_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -967,8 +911,8 @@ py_test(
|
||||
":ragged", # fixdeps: keep
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:bitwise_ops",
|
||||
"//tensorflow/python:clip_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
@ -993,7 +937,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged", # fixdeps: keep
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
@ -1012,12 +955,12 @@ py_test(
|
||||
":ragged_map_ops",
|
||||
":ragged_math_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python/keras:backend",
|
||||
"//third_party/py/numpy",
|
||||
@ -1035,7 +978,6 @@ py_test(
|
||||
":ragged_factory_ops",
|
||||
":ragged_tensor",
|
||||
":ragged_tensor_shape",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
@ -1052,8 +994,6 @@ py_test(
|
||||
deps = [
|
||||
":ragged_array_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
@ -1066,11 +1006,12 @@ py_test(
|
||||
python_version = "PY2",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":ragged_math_ops",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:errors",
|
||||
":ragged_factory_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python/eager:context",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1083,8 +1024,9 @@ py_test(
|
||||
":ragged_conversion_ops",
|
||||
":ragged_factory_ops",
|
||||
":ragged_squeeze_op",
|
||||
":ragged_test_util",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
|
@ -26,13 +26,12 @@ from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedConvertToTensorOrRaggedTensorTest(
|
||||
ragged_test_util.RaggedTensorTestCase, parameterized.TestCase):
|
||||
class RaggedConvertToTensorOrRaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
#=============================================================================
|
||||
# Tests where the 'value' param is a RaggedTensor
|
||||
@ -126,7 +125,7 @@ class RaggedConvertToTensorOrRaggedTensorTest(
|
||||
value, dtype, preferred_dtype)
|
||||
self.assertEqual(value.ragged_rank, converted.ragged_rank)
|
||||
self.assertEqual(dtypes.as_dtype(expected_dtype), converted.dtype)
|
||||
self.assertEqual(value.to_list(), self.eval_to_list(converted))
|
||||
self.assertAllEqual(value, converted)
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
|
@ -30,12 +30,11 @@ from tensorflow.python.ops.ragged import ragged_batch_gather_ops
|
||||
from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedBatchGatherOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedBatchGatherOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -149,7 +148,7 @@ class RaggedBatchGatherOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
])
|
||||
def testRaggedBatchGather(self, descr, params, indices, expected):
|
||||
result = ragged_batch_gather_ops.batch_gather(params, indices)
|
||||
self.assertRaggedEqual(result, expected)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
@parameterized.parameters([
|
||||
# Docstring example:
|
||||
@ -359,7 +358,7 @@ class RaggedBatchGatherOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
expected, ragged_rank=expected_ragged_rank or ragged_rank)
|
||||
result = ragged_batch_gather_with_default_op.batch_gather_with_default(
|
||||
params, indices, default_value)
|
||||
self.assertRaggedEqual(result, expected)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
@parameterized.parameters([
|
||||
# Dimensions:
|
||||
|
@ -27,12 +27,11 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedBooleanMaskOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedBooleanMaskOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
# Define short constants for true & false, so the data & mask can be lined
|
||||
# up in the examples below. This makes it easier to read the examples, to
|
||||
@ -243,7 +242,7 @@ class RaggedBooleanMaskOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
]) # pyformat: disable
|
||||
def testBooleanMask(self, descr, data, mask, expected):
|
||||
actual = ragged_array_ops.boolean_mask(data, mask)
|
||||
self.assertRaggedEqual(actual, expected)
|
||||
self.assertAllEqual(actual, expected)
|
||||
|
||||
def testErrors(self):
|
||||
if not context.executing_eagerly():
|
||||
|
@ -28,12 +28,11 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedConcatOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedConcatOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def _rt_inputs_to_tensors(self, rt_inputs, ragged_ranks=None):
|
||||
@ -240,7 +239,7 @@ class RaggedConcatOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertEqual(concatenated.ragged_rank, expected_ragged_rank)
|
||||
if expected_shape is not None:
|
||||
self.assertEqual(concatenated.shape.as_list(), expected_shape)
|
||||
self.assertRaggedEqual(concatenated, expected)
|
||||
self.assertAllEqual(concatenated, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
dict(
|
||||
@ -318,7 +317,7 @@ class RaggedConcatOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
"""
|
||||
rt_inputs = ragged_factory_ops.constant([[1, 2], [3, 4]])
|
||||
concatenated = ragged_concat_ops.concat(rt_inputs, 0)
|
||||
self.assertRaggedEqual(concatenated, [[1, 2], [3, 4]])
|
||||
self.assertAllEqual(concatenated, [[1, 2], [3, 4]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -26,12 +26,11 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import ragged
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedConstOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedConstOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -203,7 +202,7 @@ class RaggedConstOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
pylist, dtype=dtype, ragged_rank=ragged_rank, inner_shape=inner_shape)
|
||||
# Normalize the pylist, i.e., convert all np.arrays to list.
|
||||
# E.g., [np.array((1,2))] --> [[1,2]]
|
||||
pylist = self._normalize_pylist(pylist)
|
||||
pylist = _normalize_pylist(pylist)
|
||||
|
||||
# If dtype was explicitly specified, check it.
|
||||
if dtype is not None:
|
||||
@ -227,8 +226,11 @@ class RaggedConstOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
|
||||
if expected_shape is not None:
|
||||
self.assertEqual(tuple(rt.shape.as_list()), expected_shape)
|
||||
if (expected_shape and expected_shape[0] == 0 and
|
||||
None not in expected_shape):
|
||||
pylist = np.zeros(expected_shape, rt.dtype.as_numpy_dtype)
|
||||
|
||||
self.assertRaggedEqual(rt, pylist)
|
||||
self.assertAllEqual(rt, pylist)
|
||||
|
||||
@parameterized.parameters(
|
||||
dict(
|
||||
@ -399,5 +401,14 @@ class RaggedConstOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
pylist, ragged_rank), inner_shape)
|
||||
|
||||
|
||||
def _normalize_pylist(item):
|
||||
"""Convert all (possibly nested) np.arrays contained in item to list."""
|
||||
# convert np.arrays in current level to list
|
||||
if np.ndim(item) == 0:
|
||||
return item
|
||||
level = (x.tolist() if isinstance(x, np.ndarray) else x for x in item)
|
||||
return [_normalize_pylist(el) if np.ndim(el) != 0 else el for el in level]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
@ -24,14 +24,12 @@ import numpy as np
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedConstantValueOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedConstantValueOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
#=========================================================================
|
||||
# 0-dimensional tensors.
|
||||
@ -190,7 +188,7 @@ class RaggedConstantValueOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
pylist, dtype=dtype, ragged_rank=ragged_rank, inner_shape=inner_shape)
|
||||
# Normalize the pylist, i.e., convert all np.arrays to list.
|
||||
# E.g., [np.array((1,2))] --> [[1,2]]
|
||||
pylist = self._normalize_pylist(pylist)
|
||||
pylist = _normalize_pylist(pylist)
|
||||
# If dtype was explicitly specified, check it.
|
||||
if dtype is not None:
|
||||
self.assertEqual(rt.dtype, dtype)
|
||||
@ -315,5 +313,14 @@ class RaggedConstantValueOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
inner_shape=inner_shape)
|
||||
|
||||
|
||||
def _normalize_pylist(item):
|
||||
"""Convert all (possibly nested) np.arrays contained in item to list."""
|
||||
# convert np.arrays in current level to list
|
||||
if np.ndim(item) == 0:
|
||||
return item
|
||||
level = (x.tolist() if isinstance(x, np.ndarray) else x for x in item)
|
||||
return [_normalize_pylist(el) if np.ndim(el) != 0 else el for el in level]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
||||
|
@ -35,7 +35,6 @@ from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
# Constants listing various op types to test. Each operation
|
||||
@ -137,7 +136,7 @@ BINARY_INT_OPS = [
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def assertSameShape(self, x, y):
|
||||
@ -468,7 +467,7 @@ class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
x = ragged_tensor.convert_to_tensor_or_ragged_tensor(x, dtype=dtypes.int32)
|
||||
y = ragged_tensor.convert_to_tensor_or_ragged_tensor(y, dtype=dtypes.int32)
|
||||
result = x + y
|
||||
self.assertRaggedEqual(result, expected)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
def testElementwiseOpShapeMismatch(self):
|
||||
x = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
|
||||
@ -713,7 +712,7 @@ class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
def testRaggedDispatch(self, op, expected, args=(), kwargs=None):
|
||||
if kwargs is None: kwargs = {}
|
||||
result = op(*args, **kwargs)
|
||||
self.assertRaggedEqual(result, expected)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -21,12 +21,12 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -36,7 +36,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
])
|
||||
def testRaggedTensorToList(self, pylist, ragged_rank=None):
|
||||
rt = ragged_factory_ops.constant(pylist, ragged_rank)
|
||||
self.assertRaggedEqual(rt, pylist)
|
||||
self.assertAllEqual(rt, pylist)
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(pylist=[[b'a', b'b'], [b'c']]),
|
||||
|
@ -23,12 +23,11 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedExpandDimsOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedExpandDimsOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
# An example 4-d ragged tensor with shape [3, (D2), (D3), 2], and the
|
||||
@ -120,7 +119,7 @@ class RaggedExpandDimsOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
if expected_shape is not None:
|
||||
self.assertEqual(expanded.shape.as_list(), expected_shape)
|
||||
|
||||
self.assertRaggedEqual(expanded, expected)
|
||||
self.assertAllEqual(expanded, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -24,13 +24,12 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testDocStringExample(self):
|
||||
st = sparse_tensor.SparseTensor(
|
||||
@ -39,7 +38,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
dense_shape=[4, 3])
|
||||
rt = RaggedTensor.from_sparse(st)
|
||||
|
||||
self.assertRaggedEqual(rt, [[1, 2, 3], [4], [], [5]])
|
||||
self.assertAllEqual(rt, [[1, 2, 3], [4], [], [5]])
|
||||
|
||||
def testEmpty(self):
|
||||
st = sparse_tensor.SparseTensor(
|
||||
@ -48,7 +47,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
dense_shape=[4, 3])
|
||||
rt = RaggedTensor.from_sparse(st)
|
||||
|
||||
self.assertRaggedEqual(rt, [[], [], [], []])
|
||||
self.assertAllEqual(rt, [[], [], [], []])
|
||||
|
||||
def testBadSparseTensorRank(self):
|
||||
st1 = sparse_tensor.SparseTensor(indices=[[0]], values=[0], dense_shape=[3])
|
||||
|
@ -24,31 +24,30 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def testDocStringExamples(self):
|
||||
# The examples from RaggedTensor.from_tensor.__doc__.
|
||||
dt = constant_op.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
RaggedTensor.from_tensor(dt), [[5, 7, 0], [0, 3, 0], [6, 0, 0]])
|
||||
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
RaggedTensor.from_tensor(dt, lengths=[1, 0, 3]), [[5], [], [6, 0, 0]])
|
||||
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
RaggedTensor.from_tensor(dt, padding=0), [[5, 7], [0, 3], [6]])
|
||||
|
||||
dt_3d = constant_op.constant([[[5, 0], [7, 0], [0, 0]],
|
||||
[[0, 0], [3, 0], [0, 0]],
|
||||
[[6, 0], [0, 0], [0, 0]]])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
RaggedTensor.from_tensor(dt_3d, lengths=([2, 0, 3], [1, 1, 2, 0, 1])),
|
||||
[[[5], [7]], [], [[6, 0], [], [0]]])
|
||||
|
||||
@ -320,7 +319,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertTrue(
|
||||
dt.shape.is_compatible_with(rt.shape),
|
||||
'%s is incompatible with %s' % (dt.shape, rt.shape))
|
||||
self.assertRaggedEqual(rt, expected)
|
||||
self.assertAllEqual(rt, expected)
|
||||
|
||||
def testHighDimensions(self):
|
||||
# Use distinct prime numbers for all dimension shapes in this test, so
|
||||
@ -334,7 +333,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertTrue(
|
||||
dt.shape.is_compatible_with(rt.shape),
|
||||
'%s is incompatible with %s' % (dt.shape, rt.shape))
|
||||
self.assertRaggedEqual(rt, self.evaluate(dt).tolist())
|
||||
self.assertAllEqual(rt, self.evaluate(dt).tolist())
|
||||
|
||||
@parameterized.parameters(
|
||||
# With no padding or lengths
|
||||
@ -444,7 +443,7 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertEqual(type(rt), RaggedTensor)
|
||||
self.assertEqual(rt.ragged_rank, 1)
|
||||
self.assertTrue(dt.shape.is_compatible_with(rt.shape))
|
||||
self.assertRaggedEqual(rt, expected)
|
||||
self.assertAllEqual(rt, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
|
@ -28,12 +28,11 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_gather_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedGatherNdOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedGatherNdOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
DOCSTRING_PARAMS = [[['000', '001'], ['010']],
|
||||
@ -202,7 +201,7 @@ class RaggedGatherNdOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
]) # pyformat: disable
|
||||
def testRaggedGatherNd(self, descr, params, indices, expected):
|
||||
result = ragged_gather_ops.gather_nd(params, indices)
|
||||
self.assertRaggedEqual(result, expected)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
def testRaggedGatherNdUnknownRankError(self):
|
||||
if context.executing_eagerly():
|
||||
|
@ -26,12 +26,11 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_gather_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
class RaggedGatherOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testDocStringExamples(self):
|
||||
params = constant_op.constant(['a', 'b', 'c', 'd', 'e'])
|
||||
@ -39,20 +38,20 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
ragged_params = ragged_factory_ops.constant([['a', 'b', 'c'], ['d'], [],
|
||||
['e']])
|
||||
ragged_indices = ragged_factory_ops.constant([[3, 1, 2], [1], [], [0]])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, ragged_indices),
|
||||
[[b'd', b'b', b'c'], [b'b'], [], [b'a']])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(ragged_params, indices),
|
||||
[[b'e'], [b'd'], [], [b'd'], [b'a', b'b', b'c']])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(ragged_params, ragged_indices),
|
||||
[[[b'e'], [b'd'], []], [[b'd']], [], [[b'a', b'b', b'c']]])
|
||||
|
||||
def testTensorParamsAndTensorIndices(self):
|
||||
params = ['a', 'b', 'c', 'd', 'e']
|
||||
indices = [2, 0, 2, 1]
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices), [b'c', b'a', b'c', b'b'])
|
||||
self.assertIsInstance(ragged_gather_ops.gather(params, indices), ops.Tensor)
|
||||
|
||||
@ -60,14 +59,14 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
|
||||
[], ['g']])
|
||||
indices = [2, 0, 2, 1]
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[b'f'], [b'a', b'b'], [b'f'], [b'c', b'd', b'e']])
|
||||
|
||||
def testTensorParamsAndRaggedIndices(self):
|
||||
params = ['a', 'b', 'c', 'd', 'e']
|
||||
indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[b'c', b'b'], [b'b', b'c', b'a'], [b'd']])
|
||||
|
||||
@ -75,7 +74,7 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
|
||||
[], ['g']])
|
||||
indices = ragged_factory_ops.constant([[2, 1], [1, 2, 0], [3]])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[[b'f'], [b'c', b'd', b'e']], # [[p[2], p[1] ],
|
||||
[[b'c', b'd', b'e'], [b'f'], [b'a', b'b']], # [p[1], p[2], p[0]],
|
||||
@ -86,14 +85,14 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
params = ragged_factory_ops.constant([['a', 'b'], ['c', 'd', 'e'], ['f'],
|
||||
[], ['g']])
|
||||
indices = 1
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices), [b'c', b'd', b'e'])
|
||||
|
||||
def test3DRaggedParamsAnd2DTensorIndices(self):
|
||||
params = ragged_factory_ops.constant([[['a', 'b'], []],
|
||||
[['c', 'd'], ['e'], ['f']], [['g']]])
|
||||
indices = [[1, 2], [0, 1], [2, 2]]
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[[[b'c', b'd'], [b'e'], [b'f']], [[b'g']]], # [[p1, p2],
|
||||
[[[b'a', b'b'], []], [[b'c', b'd'], [b'e'], [b'f']]], # [p0, p1],
|
||||
@ -107,7 +106,7 @@ class RaggedGatherOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
ragged_rank=2,
|
||||
inner_shape=(2,))
|
||||
params = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_gather_ops.gather(params, indices),
|
||||
[[[[b'd', b'e'], [b'a', b'g']], []],
|
||||
[[[b'c', b'b'], [b'b', b'a']], [[b'c', b'f']], [[b'c', b'd']]],
|
||||
|
@ -27,12 +27,11 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
class RaggedMapInnerValuesOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def assertRaggedMapInnerValuesReturns(self,
|
||||
op,
|
||||
@ -41,7 +40,7 @@ class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
result = ragged_functional_ops.map_flat_values(op, *args, **kwargs)
|
||||
self.assertRaggedEqual(result, expected)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
def testDocStringExamples(self):
|
||||
"""Test the examples in apply_op_to_ragged_values.__doc__."""
|
||||
@ -49,9 +48,9 @@ class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
v1 = ragged_functional_ops.map_flat_values(array_ops.ones_like, rt)
|
||||
v2 = ragged_functional_ops.map_flat_values(math_ops.multiply, rt, rt)
|
||||
v3 = ragged_functional_ops.map_flat_values(math_ops.add, rt, 5)
|
||||
self.assertRaggedEqual(v1, [[1, 1, 1], [], [1, 1], [1]])
|
||||
self.assertRaggedEqual(v2, [[1, 4, 9], [], [16, 25], [36]])
|
||||
self.assertRaggedEqual(v3, [[6, 7, 8], [], [9, 10], [11]])
|
||||
self.assertAllEqual(v1, [[1, 1, 1], [], [1, 1], [1]])
|
||||
self.assertAllEqual(v2, [[1, 4, 9], [], [16, 25], [36]])
|
||||
self.assertAllEqual(v3, [[6, 7, 8], [], [9, 10], [11]])
|
||||
|
||||
def testOpWithSingleRaggedTensorArg(self):
|
||||
tensor = ragged_factory_ops.constant([[1, 2, 3], [], [4, 5]])
|
||||
@ -122,7 +121,7 @@ class RaggedMapInnerValuesOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
# ragged_rank=0
|
||||
x0 = [3, 1, 4, 1, 5, 9, 2, 6, 5]
|
||||
y0 = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
math_ops.multiply(x0, y0), [3, 2, 12, 4, 25, 54, 14, 48, 45])
|
||||
|
||||
# ragged_rank=1
|
||||
|
@ -32,12 +32,11 @@ from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||
from tensorflow.python.ops.ragged import ragged_map_ops
|
||||
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedMapOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -168,7 +167,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
|
||||
expected_rt = ragged_factory_ops.constant(
|
||||
expected_output, ragged_rank=expected_ragged_rank)
|
||||
self.assertRaggedEqual(expected_rt, output)
|
||||
self.assertAllEqual(expected_rt, output)
|
||||
|
||||
def testRaggedMapOnStructure(self):
|
||||
batman = ragged_factory_ops.constant([[1, 2, 3], [4], [5, 6, 7]])
|
||||
@ -186,7 +185,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
dtype=dtypes.int32,
|
||||
)
|
||||
|
||||
self.assertRaggedEqual(output, [66, 44, 198])
|
||||
self.assertAllEqual(output, [66, 44, 198])
|
||||
|
||||
# Test mapping over a dict of RTs can produce a dict of RTs.
|
||||
def testRaggedMapOnStructure_RaggedOutputs(self):
|
||||
@ -216,8 +215,8 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
},
|
||||
)
|
||||
|
||||
self.assertRaggedEqual(output['batman'], [[2, 3, 4], [5], [6, 7, 8]])
|
||||
self.assertRaggedEqual(output['robin'], [[11, 21, 31], [41], [51, 61, 71]])
|
||||
self.assertAllEqual(output['batman'], [[2, 3, 4], [5], [6, 7, 8]])
|
||||
self.assertAllEqual(output['robin'], [[11, 21, 31], [41], [51, 61, 71]])
|
||||
|
||||
def testZip(self):
|
||||
x = ragged_factory_ops.constant(
|
||||
@ -234,7 +233,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
dtype=ragged_tensor.RaggedTensorType(dtype=dtypes.int64, ragged_rank=1),
|
||||
infer_shape=False)
|
||||
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
output, [[[0, 10], [0, 20]], [[1, 30], [1, 40]], [[2, 50], [2, 60]],
|
||||
[[3, 70]], [[4, 80], [4, 90], [4, 100]]])
|
||||
|
||||
@ -255,7 +254,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
dtype=dtypes.string, ragged_rank=1),
|
||||
infer_shape=False)
|
||||
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
out, [[b'hello', b'there'], [b'merhaba'], [b'bonjour', b'ca va']])
|
||||
|
||||
def testMismatchRaggedRank(self):
|
||||
@ -290,7 +289,7 @@ class RaggedMapOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
id_t2 = ragged_map_ops.map_fn(
|
||||
lambda x: x, t2,
|
||||
)
|
||||
self.assertRaggedEqual(id_t2, [[0, 5], [0, 4]])
|
||||
self.assertAllEqual(id_t2, [[0, 5], [0, 4]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -20,71 +20,70 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedElementwiseOpsTest(ragged_test_util.RaggedTensorTestCase):
|
||||
class RaggedElementwiseOpsTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testOrderingOperators(self):
|
||||
x = ragged_factory_ops.constant([[1, 5], [3]])
|
||||
y = ragged_factory_ops.constant([[4, 5], [1]])
|
||||
self.assertRaggedEqual((x > y), [[False, False], [True]])
|
||||
self.assertRaggedEqual((x >= y), [[False, True], [True]])
|
||||
self.assertRaggedEqual((x < y), [[True, False], [False]])
|
||||
self.assertRaggedEqual((x <= y), [[True, True], [False]])
|
||||
self.assertAllEqual((x > y), [[False, False], [True]])
|
||||
self.assertAllEqual((x >= y), [[False, True], [True]])
|
||||
self.assertAllEqual((x < y), [[True, False], [False]])
|
||||
self.assertAllEqual((x <= y), [[True, True], [False]])
|
||||
|
||||
def testArithmeticOperators(self):
|
||||
x = ragged_factory_ops.constant([[1.0, -2.0], [8.0]])
|
||||
y = ragged_factory_ops.constant([[4.0, 4.0], [2.0]])
|
||||
self.assertRaggedEqual(abs(x), [[1.0, 2.0], [8.0]])
|
||||
self.assertAllEqual(abs(x), [[1.0, 2.0], [8.0]])
|
||||
|
||||
self.assertRaggedEqual((-x), [[-1.0, 2.0], [-8.0]])
|
||||
self.assertAllEqual((-x), [[-1.0, 2.0], [-8.0]])
|
||||
|
||||
self.assertRaggedEqual((x + y), [[5.0, 2.0], [10.0]])
|
||||
self.assertRaggedEqual((3.0 + y), [[7.0, 7.0], [5.0]])
|
||||
self.assertRaggedEqual((x + 3.0), [[4.0, 1.0], [11.0]])
|
||||
self.assertAllEqual((x + y), [[5.0, 2.0], [10.0]])
|
||||
self.assertAllEqual((3.0 + y), [[7.0, 7.0], [5.0]])
|
||||
self.assertAllEqual((x + 3.0), [[4.0, 1.0], [11.0]])
|
||||
|
||||
self.assertRaggedEqual((x - y), [[-3.0, -6.0], [6.0]])
|
||||
self.assertRaggedEqual((3.0 - y), [[-1.0, -1.0], [1.0]])
|
||||
self.assertRaggedEqual((x + 3.0), [[4.0, 1.0], [11.0]])
|
||||
self.assertAllEqual((x - y), [[-3.0, -6.0], [6.0]])
|
||||
self.assertAllEqual((3.0 - y), [[-1.0, -1.0], [1.0]])
|
||||
self.assertAllEqual((x + 3.0), [[4.0, 1.0], [11.0]])
|
||||
|
||||
self.assertRaggedEqual((x * y), [[4.0, -8.0], [16.0]])
|
||||
self.assertRaggedEqual((3.0 * y), [[12.0, 12.0], [6.0]])
|
||||
self.assertRaggedEqual((x * 3.0), [[3.0, -6.0], [24.0]])
|
||||
self.assertAllEqual((x * y), [[4.0, -8.0], [16.0]])
|
||||
self.assertAllEqual((3.0 * y), [[12.0, 12.0], [6.0]])
|
||||
self.assertAllEqual((x * 3.0), [[3.0, -6.0], [24.0]])
|
||||
|
||||
self.assertRaggedEqual((x / y), [[0.25, -0.5], [4.0]])
|
||||
self.assertRaggedEqual((y / x), [[4.0, -2.0], [0.25]])
|
||||
self.assertRaggedEqual((2.0 / y), [[0.5, 0.5], [1.0]])
|
||||
self.assertRaggedEqual((x / 2.0), [[0.5, -1.0], [4.0]])
|
||||
self.assertAllEqual((x / y), [[0.25, -0.5], [4.0]])
|
||||
self.assertAllEqual((y / x), [[4.0, -2.0], [0.25]])
|
||||
self.assertAllEqual((2.0 / y), [[0.5, 0.5], [1.0]])
|
||||
self.assertAllEqual((x / 2.0), [[0.5, -1.0], [4.0]])
|
||||
|
||||
self.assertRaggedEqual((x // y), [[0.0, -1.0], [4.0]])
|
||||
self.assertRaggedEqual((y // x), [[4.0, -2.0], [0.0]])
|
||||
self.assertRaggedEqual((2.0 // y), [[0.0, 0.0], [1.0]])
|
||||
self.assertRaggedEqual((x // 2.0), [[0.0, -1.0], [4.0]])
|
||||
self.assertAllEqual((x // y), [[0.0, -1.0], [4.0]])
|
||||
self.assertAllEqual((y // x), [[4.0, -2.0], [0.0]])
|
||||
self.assertAllEqual((2.0 // y), [[0.0, 0.0], [1.0]])
|
||||
self.assertAllEqual((x // 2.0), [[0.0, -1.0], [4.0]])
|
||||
|
||||
self.assertRaggedEqual((x % y), [[1.0, 2.0], [0.0]])
|
||||
self.assertRaggedEqual((y % x), [[0.0, -0.0], [2.0]])
|
||||
self.assertRaggedEqual((2.0 % y), [[2.0, 2.0], [0.0]])
|
||||
self.assertRaggedEqual((x % 2.0), [[1.0, 0.0], [0.0]])
|
||||
self.assertAllEqual((x % y), [[1.0, 2.0], [0.0]])
|
||||
self.assertAllEqual((y % x), [[0.0, -0.0], [2.0]])
|
||||
self.assertAllEqual((2.0 % y), [[2.0, 2.0], [0.0]])
|
||||
self.assertAllEqual((x % 2.0), [[1.0, 0.0], [0.0]])
|
||||
|
||||
def testLogicalOperators(self):
|
||||
a = ragged_factory_ops.constant([[True, True], [False]])
|
||||
b = ragged_factory_ops.constant([[True, False], [False]])
|
||||
self.assertRaggedEqual((~a), [[False, False], [True]])
|
||||
self.assertAllEqual((~a), [[False, False], [True]])
|
||||
|
||||
self.assertRaggedEqual((a & b), [[True, False], [False]])
|
||||
self.assertRaggedEqual((a & True), [[True, True], [False]])
|
||||
self.assertRaggedEqual((True & b), [[True, False], [False]])
|
||||
self.assertAllEqual((a & b), [[True, False], [False]])
|
||||
self.assertAllEqual((a & True), [[True, True], [False]])
|
||||
self.assertAllEqual((True & b), [[True, False], [False]])
|
||||
|
||||
self.assertRaggedEqual((a | b), [[True, True], [False]])
|
||||
self.assertRaggedEqual((a | False), [[True, True], [False]])
|
||||
self.assertRaggedEqual((False | b), [[True, False], [False]])
|
||||
self.assertAllEqual((a | b), [[True, True], [False]])
|
||||
self.assertAllEqual((a | False), [[True, True], [False]])
|
||||
self.assertAllEqual((False | b), [[True, False], [False]])
|
||||
|
||||
self.assertRaggedEqual((a ^ b), [[False, True], [False]])
|
||||
self.assertRaggedEqual((a ^ True), [[False, False], [True]])
|
||||
self.assertRaggedEqual((True ^ b), [[False, True], [True]])
|
||||
self.assertAllEqual((a ^ b), [[False, True], [False]])
|
||||
self.assertAllEqual((a ^ True), [[False, False], [True]])
|
||||
self.assertAllEqual((True ^ b), [[False, True], [True]])
|
||||
|
||||
def testDummyOperators(self):
|
||||
a = ragged_factory_ops.constant([[True, True], [False]])
|
||||
|
@ -23,12 +23,11 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedPlaceholderOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
|
@ -21,40 +21,39 @@ from __future__ import print_function
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedRangeOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
class RaggedRangeOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testDocStringExamples(self):
|
||||
"""Examples from ragged_range.__doc__."""
|
||||
rt1 = ragged_math_ops.range([3, 5, 2])
|
||||
self.assertRaggedEqual(rt1, [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]])
|
||||
self.assertAllEqual(rt1, [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]])
|
||||
|
||||
rt2 = ragged_math_ops.range([0, 5, 8], [3, 3, 12])
|
||||
self.assertRaggedEqual(rt2, [[0, 1, 2], [], [8, 9, 10, 11]])
|
||||
self.assertAllEqual(rt2, [[0, 1, 2], [], [8, 9, 10, 11]])
|
||||
|
||||
rt3 = ragged_math_ops.range([0, 5, 8], [3, 3, 12], 2)
|
||||
self.assertRaggedEqual(rt3, [[0, 2], [], [8, 10]])
|
||||
self.assertAllEqual(rt3, [[0, 2], [], [8, 10]])
|
||||
|
||||
def testBasicRanges(self):
|
||||
# Specify limits only.
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range([0, 3, 5]),
|
||||
[list(range(0)), list(range(3)),
|
||||
list(range(5))])
|
||||
|
||||
# Specify starts and limits.
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range([0, 3, 5], [2, 3, 10]),
|
||||
[list(range(0, 2)),
|
||||
list(range(3, 3)),
|
||||
list(range(5, 10))])
|
||||
|
||||
# Specify starts, limits, and deltas.
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range([0, 3, 5], [4, 4, 15], [2, 3, 4]),
|
||||
[list(range(0, 4, 2)),
|
||||
list(range(3, 4, 3)),
|
||||
@ -65,18 +64,16 @@ class RaggedRangeOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
[5.0, 7.2, 9.4, 11.6, 13.8]]
|
||||
actual = ragged_math_ops.range([0.0, 3.0, 5.0], [3.9, 4.0, 15.0],
|
||||
[0.4, 1.5, 2.2])
|
||||
self.assertEqual(
|
||||
expected,
|
||||
[[round(v, 5) for v in row] for row in self.eval_to_list(actual)])
|
||||
self.assertAllClose(actual, expected)
|
||||
|
||||
def testNegativeDeltas(self):
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range([0, 3, 5], limits=0, deltas=-1),
|
||||
[list(range(0, 0, -1)),
|
||||
list(range(3, 0, -1)),
|
||||
list(range(5, 0, -1))])
|
||||
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range([0, -3, 5], limits=0, deltas=[-1, 1, -2]),
|
||||
[list(range(0, 0, -1)),
|
||||
list(range(-3, 0, 1)),
|
||||
@ -84,21 +81,21 @@ class RaggedRangeOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
|
||||
def testBroadcast(self):
|
||||
# Specify starts and limits, broadcast deltas.
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range([0, 3, 5], [4, 4, 15], 3),
|
||||
[list(range(0, 4, 3)),
|
||||
list(range(3, 4, 3)),
|
||||
list(range(5, 15, 3))])
|
||||
|
||||
# Broadcast all arguments.
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range(0, 5, 1), [list(range(0, 5, 1))])
|
||||
|
||||
def testEmptyRanges(self):
|
||||
rt1 = ragged_math_ops.range([0, 5, 3], [0, 3, 5])
|
||||
rt2 = ragged_math_ops.range([0, 5, 5], [0, 3, 5], -1)
|
||||
self.assertRaggedEqual(rt1, [[], [], [3, 4]])
|
||||
self.assertRaggedEqual(rt2, [[], [5, 4], []])
|
||||
self.assertAllEqual(rt1, [[], [], [3, 4]])
|
||||
self.assertAllEqual(rt2, [[], [5, 4], []])
|
||||
|
||||
def testShapeFnErrors(self):
|
||||
self.assertRaises((ValueError, errors.InvalidArgumentError),
|
||||
@ -116,11 +113,11 @@ class RaggedRangeOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
self.evaluate(ragged_math_ops.range(0, 0, 0))
|
||||
|
||||
def testShape(self):
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range(0, 0, 1).shape.as_list(), [1, None])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range([1, 2, 3]).shape.as_list(), [3, None])
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
ragged_math_ops.range([1, 2, 3], [4, 5, 6]).shape.as_list(), [3, None])
|
||||
|
||||
|
||||
|
@ -21,12 +21,11 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedRankOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedRankOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
|
@ -28,7 +28,6 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
_MAX_INT32 = dtypes.int32.max
|
||||
@ -41,7 +40,7 @@ def mean(*values):
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedReduceOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedReduceOpsTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -320,7 +319,7 @@ class RaggedReduceOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
def testReduce(self, ragged_reduce_op, rt_input, axis, expected):
|
||||
rt_input = ragged_factory_ops.constant(rt_input)
|
||||
reduced = ragged_reduce_op(rt_input, axis)
|
||||
self.assertRaggedEqual(reduced, expected)
|
||||
self.assertAllEqual(reduced, expected)
|
||||
|
||||
def assertEqualWithNan(self, actual, expected):
|
||||
"""Like assertEqual, but NaN==NaN."""
|
||||
@ -340,7 +339,7 @@ class RaggedReduceOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
tensor = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]
|
||||
expected = [2.0, 20.0]
|
||||
reduced = ragged_math_ops.reduce_mean(tensor, axis=1)
|
||||
self.assertRaggedEqual(reduced, expected)
|
||||
self.assertAllEqual(reduced, expected)
|
||||
|
||||
def testErrors(self):
|
||||
rt_input = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
|
||||
|
@ -24,12 +24,11 @@ from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedRowLengthsOp(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedRowLengthsOp(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -120,7 +119,7 @@ class RaggedRowLengthsOp(ragged_test_util.RaggedTensorTestCase,
|
||||
expected_ragged_rank=None):
|
||||
rt = ragged_factory_ops.constant(rt_input, ragged_rank=ragged_rank)
|
||||
lengths = rt.row_lengths(axis)
|
||||
self.assertRaggedEqual(lengths, expected)
|
||||
self.assertAllEqual(lengths, expected)
|
||||
if expected_ragged_rank is not None:
|
||||
if isinstance(lengths, ragged_tensor.RaggedTensor):
|
||||
self.assertEqual(lengths.ragged_rank, expected_ragged_rank)
|
||||
|
@ -20,13 +20,12 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.ops.ragged import segment_id_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedSplitsToSegmentIdsOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
class RaggedSplitsToSegmentIdsOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testDocStringExample(self):
|
||||
splits = [0, 3, 3, 5, 6, 9]
|
||||
|
@ -20,13 +20,12 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.ops.ragged import segment_id_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedSplitsToSegmentIdsOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
class RaggedSplitsToSegmentIdsOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testDocStringExample(self):
|
||||
segment_ids = [0, 0, 0, 2, 2, 3, 4, 4, 4]
|
||||
|
@ -28,7 +28,6 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@ -49,7 +48,7 @@ def sqrt_n(values):
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedSegmentOpsTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def expected_value(self, data, segment_ids, num_segments, combiner):
|
||||
@ -110,7 +109,7 @@ class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
combiner)
|
||||
|
||||
segmented = segment_op(rt, segment_ids, num_segments)
|
||||
self.assertRaggedEqual(segmented, expected)
|
||||
self.assertAllEqual(segmented, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
(ragged_math_ops.segment_sum, sum, [0, 0, 1, 1, 2, 2]),
|
||||
@ -146,7 +145,7 @@ class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
combiner)
|
||||
|
||||
segmented = segment_op(rt, segment_ids, num_segments)
|
||||
self.assertRaggedAlmostEqual(segmented, expected, places=5)
|
||||
self.assertAllClose(segmented, expected)
|
||||
|
||||
def testRaggedRankTwo(self):
|
||||
rt = ragged_factory_ops.constant([
|
||||
@ -161,14 +160,14 @@ class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
[], # row 1
|
||||
[[411, 412], [321, 322], [331]] # row 2
|
||||
] # pyformat: disable
|
||||
self.assertRaggedEqual(segmented1, expected1)
|
||||
self.assertAllEqual(segmented1, expected1)
|
||||
|
||||
segment_ids2 = [1, 2, 1, 1]
|
||||
segmented2 = ragged_math_ops.segment_sum(rt, segment_ids2, 3)
|
||||
expected2 = [[],
|
||||
[[111+411, 112+412, 113, 114], [121+321, 322], [331]],
|
||||
[]] # pyformat: disable
|
||||
self.assertRaggedEqual(segmented2, expected2)
|
||||
self.assertAllEqual(segmented2, expected2)
|
||||
|
||||
def testRaggedSegmentIds(self):
|
||||
rt = ragged_factory_ops.constant([
|
||||
@ -182,7 +181,7 @@ class RaggedSegmentOpsTest(ragged_test_util.RaggedTensorTestCase,
|
||||
expected = [[],
|
||||
[111+321, 112+322, 113, 114],
|
||||
[121+331+411, 412]] # pyformat: disable
|
||||
self.assertRaggedEqual(segmented, expected)
|
||||
self.assertAllEqual(segmented, expected)
|
||||
|
||||
def testShapeMismatchError1(self):
|
||||
dt = constant_op.constant([1, 2, 3, 4, 5, 6])
|
||||
|
@ -23,12 +23,11 @@ from absl.testing import parameterized
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedSizeOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedSizeOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
|
@ -27,12 +27,11 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_conversion_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_squeeze_op
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedSqueezeTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -52,7 +51,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt = ragged_squeeze_op.squeeze(
|
||||
ragged_factory_ops.constant(input_list), squeeze_ranks)
|
||||
dt = array_ops.squeeze(constant_op.constant(input_list), squeeze_ranks)
|
||||
self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt), dt)
|
||||
self.assertAllEqual(ragged_conversion_ops.to_tensor(rt), dt)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
@ -112,7 +111,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt = ragged_squeeze_op.squeeze(
|
||||
ragged_factory_ops.constant(input_list), squeeze_ranks)
|
||||
dt = array_ops.squeeze(constant_op.constant(input_list), squeeze_ranks)
|
||||
self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt), dt)
|
||||
self.assertAllEqual(ragged_conversion_ops.to_tensor(rt), dt)
|
||||
|
||||
@parameterized.parameters([
|
||||
# ragged_conversion_ops.from_tensor does not work for this
|
||||
@ -167,7 +166,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt = ragged_conversion_ops.from_tensor(dt)
|
||||
rt_s = ragged_squeeze_op.squeeze(rt, squeeze_ranks)
|
||||
dt_s = array_ops.squeeze(dt, squeeze_ranks)
|
||||
self.assertRaggedEqual(ragged_conversion_ops.to_tensor(rt_s), dt_s)
|
||||
self.assertAllEqual(ragged_conversion_ops.to_tensor(rt_s), dt_s)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
@ -185,7 +184,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt = ragged_factory_ops.constant(input_list)
|
||||
rt_s = ragged_squeeze_op.squeeze(rt, squeeze_ranks)
|
||||
ref = ragged_factory_ops.constant(output_list)
|
||||
self.assertRaggedEqual(rt_s, ref)
|
||||
self.assertAllEqual(rt_s, ref)
|
||||
|
||||
def test_passing_text(self):
|
||||
rt = ragged_factory_ops.constant([[[[[[[['H']], [['e']], [['l']], [['l']],
|
||||
@ -202,7 +201,7 @@ class RaggedSqueezeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
['M', 'e', 'h', 'r', 'd', 'a', 'd'], ['.']]]
|
||||
ref = ragged_factory_ops.constant(output_list)
|
||||
rt_s = ragged_squeeze_op.squeeze(rt, [0, 1, 3, 6, 7])
|
||||
self.assertRaggedEqual(rt_s, ref)
|
||||
self.assertAllEqual(rt_s, ref)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
|
@ -24,12 +24,11 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_concat_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedStackOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedStackOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -335,7 +334,7 @@ class RaggedStackOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertEqual(stacked.ragged_rank, expected_ragged_rank)
|
||||
if expected_shape is not None:
|
||||
self.assertEqual(stacked.shape.as_list(), expected_shape)
|
||||
self.assertRaggedEqual(stacked, expected)
|
||||
self.assertAllEqual(stacked, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
dict(
|
||||
@ -372,7 +371,7 @@ class RaggedStackOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
"""
|
||||
rt_inputs = ragged_factory_ops.constant([[1, 2], [3, 4]])
|
||||
stacked = ragged_concat_ops.stack(rt_inputs, 0)
|
||||
self.assertRaggedEqual(stacked, [[[1, 2], [3, 4]]])
|
||||
self.assertAllEqual(stacked, [[[1, 2], [3, 4]]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -21,43 +21,42 @@ from __future__ import print_function
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTensorBoundingShapeOp(ragged_test_util.RaggedTensorTestCase):
|
||||
class RaggedTensorBoundingShapeOp(test_util.TensorFlowTestCase):
|
||||
|
||||
def testDocStringExample(self):
|
||||
# This is the example from ragged.bounding_shape.__doc__.
|
||||
rt = ragged_factory_ops.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9],
|
||||
[10]])
|
||||
self.assertRaggedEqual(rt.bounding_shape(), [5, 4])
|
||||
self.assertAllEqual(rt.bounding_shape(), [5, 4])
|
||||
|
||||
def test2DRaggedTensorWithOneRaggedDimension(self):
|
||||
values = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
|
||||
rt1 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 2, 5, 6, 6, 7])
|
||||
rt2 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 7])
|
||||
rt3 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 0, 7, 7])
|
||||
self.assertRaggedEqual(rt1.bounding_shape(), [5, 3])
|
||||
self.assertRaggedEqual(rt2.bounding_shape(), [1, 7])
|
||||
self.assertRaggedEqual(rt3.bounding_shape(), [3, 7])
|
||||
self.assertAllEqual(rt1.bounding_shape(), [5, 3])
|
||||
self.assertAllEqual(rt2.bounding_shape(), [1, 7])
|
||||
self.assertAllEqual(rt3.bounding_shape(), [3, 7])
|
||||
|
||||
def test3DRaggedTensorWithOneRaggedDimension(self):
|
||||
values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
|
||||
rt1 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 2, 5, 6, 6, 7])
|
||||
rt2 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 7])
|
||||
rt3 = ragged_tensor.RaggedTensor.from_row_splits(values, [0, 0, 7, 7])
|
||||
self.assertRaggedEqual(rt1.bounding_shape(), [5, 3, 2])
|
||||
self.assertRaggedEqual(rt2.bounding_shape(), [1, 7, 2])
|
||||
self.assertRaggedEqual(rt3.bounding_shape(), [3, 7, 2])
|
||||
self.assertAllEqual(rt1.bounding_shape(), [5, 3, 2])
|
||||
self.assertAllEqual(rt2.bounding_shape(), [1, 7, 2])
|
||||
self.assertAllEqual(rt3.bounding_shape(), [3, 7, 2])
|
||||
|
||||
def testExplicitAxisOptimizations(self):
|
||||
rt = ragged_tensor.RaggedTensor.from_row_splits(b'a b c d e f g'.split(),
|
||||
[0, 2, 5, 6, 6, 7])
|
||||
self.assertRaggedEqual(rt.bounding_shape(0), 5)
|
||||
self.assertRaggedEqual(rt.bounding_shape(1), 3)
|
||||
self.assertRaggedEqual(rt.bounding_shape([1, 0]), [3, 5])
|
||||
self.assertAllEqual(rt.bounding_shape(0), 5)
|
||||
self.assertAllEqual(rt.bounding_shape(1), 3)
|
||||
self.assertAllEqual(rt.bounding_shape([1, 0]), [3, 5])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -26,27 +26,20 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_shape
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.ops.ragged.ragged_tensor_shape import RaggedTensorDynamicShape
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTensorShapeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedTensorShapeTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def assertShapeEq(self, x, y):
|
||||
assert isinstance(x, RaggedTensorDynamicShape)
|
||||
assert isinstance(y, RaggedTensorDynamicShape)
|
||||
x_partitioned_dim_sizes = [
|
||||
self.eval_to_list(splits) #
|
||||
for splits in x.partitioned_dim_sizes
|
||||
]
|
||||
y_partitioned_dim_sizes = [
|
||||
self.eval_to_list(splits) #
|
||||
for splits in y.partitioned_dim_sizes
|
||||
]
|
||||
self.assertEqual(x_partitioned_dim_sizes, y_partitioned_dim_sizes)
|
||||
self.assertLen(x.partitioned_dim_sizes, len(y.partitioned_dim_sizes))
|
||||
for x_dims, y_dims in zip(x.partitioned_dim_sizes, y.partitioned_dim_sizes):
|
||||
self.assertAllEqual(x_dims, y_dims)
|
||||
self.assertAllEqual(x.inner_dim_sizes, y.inner_dim_sizes)
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -422,7 +415,7 @@ class RaggedTensorShapeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
result = ragged_tensor_shape.broadcast_to(x, shape)
|
||||
self.assertEqual(
|
||||
getattr(result, 'ragged_rank', 0), getattr(expected, 'ragged_rank', 0))
|
||||
self.assertRaggedEqual(result, expected)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
[
|
||||
@ -484,7 +477,7 @@ class RaggedTensorShapeTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertEqual(expected_rrank, result_rrank)
|
||||
if hasattr(expected, 'tolist'):
|
||||
expected = expected.tolist()
|
||||
self.assertRaggedEqual(result, expected)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -34,7 +34,6 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
@ -114,7 +113,7 @@ EXAMPLE_RAGGED_TENSOR_4D_VALUES = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name
|
||||
|
||||
@ -126,7 +125,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
# From section: "Component Tensors"
|
||||
rt = RaggedTensor.from_row_splits(
|
||||
values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
|
||||
self.assertRaggedEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
|
||||
self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
|
||||
del rt
|
||||
|
||||
# From section: "Alternative Row-Partitioning Schemes"
|
||||
@ -138,7 +137,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt4 = RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
|
||||
rt5 = RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
|
||||
for rt in (rt1, rt2, rt3, rt4, rt5):
|
||||
self.assertRaggedEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
|
||||
self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
|
||||
del rt1, rt2, rt3, rt4, rt5
|
||||
|
||||
# From section: "Multiple Ragged Dimensions"
|
||||
@ -147,8 +146,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
outer_rt = RaggedTensor.from_row_splits(
|
||||
values=inner_rt, row_splits=[0, 3, 3, 5])
|
||||
self.assertEqual(outer_rt.ragged_rank, 2)
|
||||
self.assertEqual(
|
||||
self.eval_to_list(outer_rt),
|
||||
self.assertAllEqual(
|
||||
outer_rt,
|
||||
[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
|
||||
del inner_rt, outer_rt
|
||||
|
||||
@ -156,15 +155,15 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt = RaggedTensor.from_nested_row_splits(
|
||||
flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
|
||||
nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8]))
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt), [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
|
||||
self.assertAllEqual(
|
||||
rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
|
||||
del rt
|
||||
|
||||
# From section: "Uniform Inner Dimensions"
|
||||
rt = RaggedTensor.from_row_splits(
|
||||
values=array_ops.ones([5, 3]), row_splits=[0, 2, 5])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
|
||||
self.assertEqual(rt.shape.as_list(), [2, None, 3])
|
||||
del rt
|
||||
@ -211,8 +210,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
rt = RaggedTensor(values=values, row_splits=row_splits, internal=True)
|
||||
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testRaggedTensorConstructionErrors(self):
|
||||
@ -267,9 +266,9 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertAllEqual(rt_value_rowids, value_rowids)
|
||||
self.assertEqual(self.eval_to_list(rt_nrows), 5)
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromValueRowIdsWithDerivedNRowsDynamic(self):
|
||||
@ -293,9 +292,9 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertAllEqual(rt_value_rowids, value_rowids)
|
||||
self.assertEqual(self.eval_to_list(rt_nrows), 5)
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromValueRowIdsWithExplicitNRows(self):
|
||||
@ -316,8 +315,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertIs(rt_nrows, nrows) # cached_nrows
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
|
||||
|
||||
def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
|
||||
@ -340,8 +339,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertIs(rt_nrows, nrows) # cached_nrows
|
||||
self.assertAllEqual(rt_value_rowids, value_rowids)
|
||||
self.assertAllEqual(rt_nrows, nrows)
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromValueRowIdsWithEmptyValues(self):
|
||||
@ -352,8 +351,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertEqual(rt.ragged_rank, 1)
|
||||
self.assertEqual(rt.values.shape.as_list(), [0])
|
||||
self.assertEqual(rt.value_rowids().shape.as_list(), [0])
|
||||
self.assertEqual(self.eval_to_list(rt_nrows), 0)
|
||||
self.assertEqual(self.eval_to_list(rt), [])
|
||||
self.assertAllEqual(rt_nrows, 0)
|
||||
self.assertAllEqual(rt, [])
|
||||
|
||||
def testFromRowSplits(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
@ -370,9 +369,9 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertIs(rt_row_splits, row_splits)
|
||||
self.assertEqual(self.eval_to_list(rt_nrows), 5)
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromRowSplitsWithEmptySplits(self):
|
||||
@ -394,10 +393,10 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt_nrows = rt.nrows()
|
||||
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertEqual(self.eval_to_list(rt_nrows), 5)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_starts, row_starts)
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromRowLimits(self):
|
||||
@ -414,10 +413,10 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt_nrows = rt.nrows()
|
||||
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertEqual(self.eval_to_list(rt_nrows), 5)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_limits, row_limits)
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromRowLengths(self):
|
||||
@ -435,10 +434,10 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertIs(rt_row_lengths, row_lengths) # cached_nrows
|
||||
self.assertEqual(self.eval_to_list(rt_nrows), 5)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_lengths, row_lengths)
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromNestedValueRowIdsWithDerivedNRows(self):
|
||||
@ -461,8 +460,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertIs(rt_values_values, values)
|
||||
self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
|
||||
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
|
||||
def testFromNestedValueRowIdsWithExplicitNRows(self):
|
||||
@ -494,8 +493,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
|
||||
self.assertAllEqual(rt_nrows, nrows[0])
|
||||
self.assertAllEqual(rt_values_nrows, nrows[1])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt), [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
|
||||
self.assertAllEqual(
|
||||
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
|
||||
[[b'f'], [b'g'], []], [], []])
|
||||
|
||||
def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self):
|
||||
@ -541,8 +540,8 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertIs(rt_values_values, flat_values)
|
||||
self.assertIs(rt_row_splits, nested_row_splits[0])
|
||||
self.assertIs(rt_values_row_splits, nested_row_splits[1])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
|
||||
def testFromNestedRowSplitsWithNonListInput(self):
|
||||
@ -609,7 +608,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
|
||||
|
||||
for rt in [rt1, rt2]:
|
||||
self.assertRaggedEqual(
|
||||
self.assertAllEqual(
|
||||
rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertAllEqual(rt.values, [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
|
||||
self.assertEqual(rt.values.shape.dims[0].value, 7)
|
||||
@ -632,26 +631,26 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
|
||||
|
||||
for rt in [rt1, rt2]:
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]], [[10, 11]],
|
||||
[[12, 13]]])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt.values),
|
||||
self.assertAllEqual(
|
||||
rt.values,
|
||||
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
|
||||
self.assertEqual(rt.values.shape.dims[0].value, 7)
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt.value_rowids()), [0, 0, 2, 2, 2, 3, 4])
|
||||
self.assertEqual(self.eval_to_list(rt.nrows()), 5)
|
||||
self.assertEqual(self.eval_to_list(rt.row_splits), [0, 2, 2, 5, 6, 7])
|
||||
self.assertEqual(self.eval_to_list(rt.row_starts()), [0, 2, 2, 5, 6])
|
||||
self.assertEqual(self.eval_to_list(rt.row_limits()), [2, 2, 5, 6, 7])
|
||||
self.assertEqual(self.eval_to_list(rt.row_lengths()), [2, 0, 3, 1, 1])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt.flat_values),
|
||||
self.assertAllEqual(
|
||||
rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
|
||||
self.assertAllEqual(rt.nrows(), 5)
|
||||
self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
|
||||
self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
|
||||
self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7])
|
||||
self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1])
|
||||
self.assertAllEqual(
|
||||
rt.flat_values,
|
||||
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
|
||||
self.assertEqual([self.eval_to_list(s) for s in rt.nested_row_splits],
|
||||
[[0, 2, 2, 5, 6, 7]])
|
||||
self.assertLen(rt.nested_row_splits, 1)
|
||||
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])
|
||||
|
||||
def testRaggedTensorAccessors_3d_with_ragged_rank_2(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
@ -667,24 +666,25 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt2 = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids)
|
||||
|
||||
for rt in [rt1, rt2]:
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt.values),
|
||||
self.assertAllEqual(
|
||||
rt.values,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertEqual(rt.values.shape.dims[0].value, 5)
|
||||
self.assertEqual(self.eval_to_list(rt.value_rowids()), [0, 0, 1, 3, 3])
|
||||
self.assertEqual(self.eval_to_list(rt.nrows()), 4)
|
||||
self.assertEqual(self.eval_to_list(rt.row_splits), [0, 2, 3, 3, 5])
|
||||
self.assertEqual(self.eval_to_list(rt.row_starts()), [0, 2, 3, 3])
|
||||
self.assertEqual(self.eval_to_list(rt.row_limits()), [2, 3, 3, 5])
|
||||
self.assertEqual(self.eval_to_list(rt.row_lengths()), [2, 1, 0, 2])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt.flat_values),
|
||||
self.assertAllEqual(rt.value_rowids(), [0, 0, 1, 3, 3])
|
||||
self.assertAllEqual(rt.nrows(), 4)
|
||||
self.assertAllEqual(rt.row_splits, [0, 2, 3, 3, 5])
|
||||
self.assertAllEqual(rt.row_starts(), [0, 2, 3, 3])
|
||||
self.assertAllEqual(rt.row_limits(), [2, 3, 3, 5])
|
||||
self.assertAllEqual(rt.row_lengths(), [2, 1, 0, 2])
|
||||
self.assertAllEqual(
|
||||
rt.flat_values,
|
||||
[b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
|
||||
self.assertEqual([self.eval_to_list(s) for s in rt.nested_row_splits],
|
||||
[[0, 2, 3, 3, 5], [0, 2, 2, 5, 6, 7]])
|
||||
self.assertLen(rt.nested_row_splits, 2)
|
||||
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5])
|
||||
self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7])
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor.shape
|
||||
@ -742,12 +742,12 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
"""
|
||||
tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
|
||||
tensor_slice_spec2 = _make_tensor_slice_spec(slice_spec, False)
|
||||
value1 = self.eval_to_list(rt.__getitem__(slice_spec))
|
||||
value2 = self.eval_to_list(rt.__getitem__(tensor_slice_spec1))
|
||||
value3 = self.eval_to_list(rt.__getitem__(tensor_slice_spec2))
|
||||
self.assertEqual(value1, expected, 'slice_spec=%s' % (slice_spec,))
|
||||
self.assertEqual(value2, expected, 'slice_spec=%s' % (slice_spec,))
|
||||
self.assertEqual(value3, expected, 'slice_spec=%s' % (slice_spec,))
|
||||
value1 = rt.__getitem__(slice_spec)
|
||||
value2 = rt.__getitem__(tensor_slice_spec1)
|
||||
value3 = rt.__getitem__(tensor_slice_spec2)
|
||||
self.assertAllEqual(value1, expected, 'slice_spec=%s' % (slice_spec,))
|
||||
self.assertAllEqual(value2, expected, 'slice_spec=%s' % (slice_spec,))
|
||||
self.assertAllEqual(value3, expected, 'slice_spec=%s' % (slice_spec,))
|
||||
|
||||
def _TestGetItemException(self, rt, slice_spec, expected, message):
|
||||
"""Helper function for testing RaggedTensor.__getitem__ exceptions."""
|
||||
@ -826,7 +826,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES,
|
||||
EXAMPLE_RAGGED_TENSOR_2D_SPLITS)
|
||||
|
||||
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_2D)
|
||||
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_2D)
|
||||
self._TestGetItem(rt, slice_spec, expected)
|
||||
|
||||
# pylint: disable=invalid-slice-index
|
||||
@ -872,7 +872,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES,
|
||||
EXAMPLE_RAGGED_TENSOR_2D_SPLITS)
|
||||
|
||||
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_2D)
|
||||
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_2D)
|
||||
self._TestGetItemException(rt, slice_spec, expected, message)
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -946,7 +946,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt = RaggedTensor.from_nested_row_splits(
|
||||
EXAMPLE_RAGGED_TENSOR_4D_VALUES,
|
||||
[EXAMPLE_RAGGED_TENSOR_4D_SPLITS1, EXAMPLE_RAGGED_TENSOR_4D_SPLITS2])
|
||||
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_4D)
|
||||
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_4D)
|
||||
self._TestGetItem(rt, slice_spec, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -973,7 +973,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt = RaggedTensor.from_nested_row_splits(
|
||||
EXAMPLE_RAGGED_TENSOR_4D_VALUES,
|
||||
[EXAMPLE_RAGGED_TENSOR_4D_SPLITS1, EXAMPLE_RAGGED_TENSOR_4D_SPLITS2])
|
||||
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_4D)
|
||||
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_4D)
|
||||
self._TestGetItemException(rt, slice_spec, expected, message)
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -1015,7 +1015,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
EXAMPLE_RAGGED_TENSOR_2D_SPLITS, dtype=dtypes.int64)
|
||||
splits = array_ops.placeholder_with_default(splits, None)
|
||||
rt = RaggedTensor.from_row_splits(EXAMPLE_RAGGED_TENSOR_2D_VALUES, splits)
|
||||
self.assertEqual(self.eval_to_list(rt), EXAMPLE_RAGGED_TENSOR_2D)
|
||||
self.assertAllEqual(rt, EXAMPLE_RAGGED_TENSOR_2D)
|
||||
self._TestGetItem(rt, slice_spec, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
@ -1042,23 +1042,23 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt_newaxis3 = rt[:, :, :, array_ops.newaxis]
|
||||
rt_newaxis4 = rt[:, :, :, :, array_ops.newaxis]
|
||||
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt),
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt_newaxis0),
|
||||
self.assertAllEqual(
|
||||
rt_newaxis0,
|
||||
[[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt_newaxis1),
|
||||
self.assertAllEqual(
|
||||
rt_newaxis1,
|
||||
[[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]]], [[]]])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt_newaxis2),
|
||||
self.assertAllEqual(
|
||||
rt_newaxis2,
|
||||
[[[[[b'a', b'b'], [b'c', b'd']]], [[]], [[[b'e', b'f']]]], []])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt_newaxis3),
|
||||
self.assertAllEqual(
|
||||
rt_newaxis3,
|
||||
[[[[[b'a', b'b']], [[b'c', b'd']]], [], [[[b'e', b'f']]]], []])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt_newaxis4),
|
||||
self.assertAllEqual(
|
||||
rt_newaxis4,
|
||||
[[[[[b'a'], [b'b']], [[b'c'], [b'd']]], [], [[[b'e'], [b'f']]]], []])
|
||||
|
||||
self.assertEqual(rt.ragged_rank, 2)
|
||||
@ -1121,14 +1121,14 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
rt2_times_10 = rt2.with_flat_values(rt2.flat_values * 10)
|
||||
rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1))
|
||||
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt1_plus_10),
|
||||
self.assertAllEqual(
|
||||
rt1_plus_10,
|
||||
[[11, 12], [13, 14, 15], [16], [], [17]])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt2_times_10),
|
||||
self.assertAllEqual(
|
||||
rt2_times_10,
|
||||
[[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
|
||||
self.assertEqual(
|
||||
self.eval_to_list(rt1_expanded),
|
||||
self.assertAllEqual(
|
||||
rt1_expanded,
|
||||
[[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
|
||||
|
||||
#=============================================================================
|
||||
@ -1494,7 +1494,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
et = rt._to_variant()
|
||||
round_trip_rt = RaggedTensor._from_variant(
|
||||
et, dtype, output_ragged_rank=ragged_rank)
|
||||
self.assertRaggedEqual(rt, round_trip_rt)
|
||||
self.assertAllEqual(rt, round_trip_rt)
|
||||
|
||||
def testBatchedVariantRoundTripInputRaggedRankInferred(self):
|
||||
ragged_rank = 1
|
||||
@ -1510,7 +1510,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4],
|
||||
[5]],
|
||||
[[6], [7]], [[8], [9]]])
|
||||
self.assertRaggedEqual(decoded_rt, expected_rt)
|
||||
self.assertAllEqual(decoded_rt, expected_rt)
|
||||
|
||||
def testBatchedVariantRoundTripWithInputRaggedRank(self):
|
||||
ragged_rank = 1
|
||||
@ -1527,7 +1527,7 @@ class RaggedTensorTest(ragged_test_util.RaggedTensorTestCase,
|
||||
expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4],
|
||||
[5]],
|
||||
[[6], [7]], [[8], [9]]])
|
||||
self.assertRaggedEqual(decoded_rt, expected_rt)
|
||||
self.assertAllEqual(decoded_rt, expected_rt)
|
||||
|
||||
def testFromVariantInvalidParams(self):
|
||||
rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
|
||||
|
@ -1,106 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
"""Test utils for tensorflow RaggedTensors."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
|
||||
|
||||
class RaggedTensorTestCase(test_util.TensorFlowTestCase):
|
||||
"""Base class for RaggedTensor test cases."""
|
||||
|
||||
def _GetPyList(self, a):
|
||||
"""Converts a to a nested python list."""
|
||||
if isinstance(a, ragged_tensor.RaggedTensor):
|
||||
return self.evaluate(a).to_list()
|
||||
elif isinstance(a, ops.Tensor):
|
||||
a = self.evaluate(a)
|
||||
return a.tolist() if isinstance(a, np.ndarray) else a
|
||||
elif isinstance(a, np.ndarray):
|
||||
return a.tolist()
|
||||
elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
|
||||
return a.to_list()
|
||||
else:
|
||||
return np.array(a).tolist()
|
||||
|
||||
def assertRaggedEqual(self, a, b):
|
||||
"""Asserts that two potentially ragged tensors are equal."""
|
||||
a_list = self._GetPyList(a)
|
||||
b_list = self._GetPyList(b)
|
||||
self.assertEqual(a_list, b_list)
|
||||
|
||||
if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
|
||||
a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
|
||||
b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
|
||||
self.assertEqual(a_ragged_rank, b_ragged_rank)
|
||||
|
||||
def assertRaggedAlmostEqual(self, a, b, places=7):
|
||||
a_list = self._GetPyList(a)
|
||||
b_list = self._GetPyList(b)
|
||||
self.assertNestedListAlmostEqual(a_list, b_list, places, context='value')
|
||||
|
||||
if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))):
|
||||
a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
|
||||
b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
|
||||
self.assertEqual(a_ragged_rank, b_ragged_rank)
|
||||
|
||||
def assertNestedListAlmostEqual(self, a, b, places=7, context='value'):
|
||||
self.assertEqual(type(a), type(b))
|
||||
if isinstance(a, (list, tuple)):
|
||||
self.assertLen(a, len(b), 'Length differs for %s' % context)
|
||||
for i in range(len(a)):
|
||||
self.assertNestedListAlmostEqual(a[i], b[i], places,
|
||||
'%s[%s]' % (context, i))
|
||||
else:
|
||||
self.assertAlmostEqual(
|
||||
a, b, places,
|
||||
'%s != %s within %s places at %s' % (a, b, places, context))
|
||||
|
||||
def eval_to_list(self, tensor):
|
||||
value = self.evaluate(tensor)
|
||||
if ragged_tensor.is_ragged(value):
|
||||
return value.to_list()
|
||||
elif isinstance(value, np.ndarray):
|
||||
return value.tolist()
|
||||
else:
|
||||
return value
|
||||
|
||||
def _eval_tensor(self, tensor):
|
||||
if ragged_tensor.is_ragged(tensor):
|
||||
return ragged_tensor_value.RaggedTensorValue(
|
||||
self._eval_tensor(tensor.values),
|
||||
self._eval_tensor(tensor.row_splits))
|
||||
else:
|
||||
return test_util.TensorFlowTestCase._eval_tensor(self, tensor)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_pylist(item):
|
||||
"""Convert all (possibly nested) np.arrays contained in item to list."""
|
||||
# convert np.arrays in current level to list
|
||||
if np.ndim(item) == 0:
|
||||
return item
|
||||
level = (x.tolist() if isinstance(x, np.ndarray) else x for x in item)
|
||||
_normalize = RaggedTensorTestCase._normalize_pylist
|
||||
return [_normalize(el) if np.ndim(el) != 0 else el for el in level]
|
@ -26,12 +26,11 @@ from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTileOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedTileOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -209,7 +208,7 @@ class RaggedTileOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertEqual(tiled.shape.ndims, rt.shape.ndims)
|
||||
if multiples_tensor is const_multiples:
|
||||
self.assertEqual(tiled.shape.as_list(), expected_shape)
|
||||
self.assertRaggedEqual(tiled, expected)
|
||||
self.assertAllEqual(tiled, expected)
|
||||
|
||||
def testRaggedTileWithTensorInput(self):
|
||||
# When the input is a `Tensor`, ragged_tile just delegates to tf.tile.
|
||||
@ -218,7 +217,7 @@ class RaggedTileOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
expected = [[1, 2, 1, 2], [3, 4, 3, 4],
|
||||
[1, 2, 1, 2], [3, 4, 3, 4],
|
||||
[1, 2, 1, 2], [3, 4, 3, 4]] # pyformat: disable
|
||||
self.assertRaggedEqual(tiled, expected)
|
||||
self.assertAllEqual(tiled, expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -28,12 +28,11 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_functional_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
class RaggedTensorToSparseOpTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testDocStringExample(self):
|
||||
rt = ragged_factory_ops.constant([[1, 2, 3], [4], [], [5, 6]])
|
||||
@ -193,8 +192,8 @@ class RaggedTensorToSparseOpTest(ragged_test_util.RaggedTensorTestCase):
|
||||
|
||||
g1, g2 = gradients_impl.gradients(st.values,
|
||||
[rt1.flat_values, rt2.flat_values])
|
||||
self.assertRaggedEqual(g1, [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])
|
||||
self.assertRaggedEqual(g2, [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]])
|
||||
self.assertAllEqual(g1, [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])
|
||||
self.assertAllEqual(g2, [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -19,17 +19,16 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTensorToTensorOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedTensorToTensorOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
def testDocStringExamples(self):
|
||||
@ -105,10 +104,9 @@ class RaggedTensorToTensorOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
self.assertIsInstance(dt, ops.Tensor)
|
||||
self.assertEqual(rt.dtype, dt.dtype)
|
||||
self.assertTrue(dt.shape.is_compatible_with(rt.shape))
|
||||
self.assertAllEqual(self.eval_to_list(dt), expected)
|
||||
if expected_shape is not None:
|
||||
dt_shape = array_ops.shape(dt)
|
||||
self.assertAllEqual(dt_shape, expected_shape)
|
||||
expected = np.ndarray(expected_shape, buffer=np.array(expected))
|
||||
self.assertAllEqual(dt, expected)
|
||||
|
||||
@parameterized.parameters(
|
||||
{
|
||||
|
@ -24,7 +24,6 @@ import numpy as np
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
@ -43,7 +42,7 @@ TENSOR_4D = [[[[('%d%d%d%d' % (i, j, k, l)).encode('utf-8')
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedUtilTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedUtilTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
|
@ -20,13 +20,12 @@ from __future__ import print_function
|
||||
from absl.testing import parameterized
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_test_util
|
||||
from tensorflow.python.ops.ragged import ragged_where_op
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedWhereOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
class RaggedWhereOpTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -188,7 +187,7 @@ class RaggedWhereOpTest(ragged_test_util.RaggedTensorTestCase,
|
||||
]) # pyformat: disable
|
||||
def testRaggedWhere(self, condition, expected, x=None, y=None):
|
||||
result = ragged_where_op.where(condition, x, y)
|
||||
self.assertRaggedEqual(result, expected)
|
||||
self.assertAllEqual(result, expected)
|
||||
|
||||
@parameterized.parameters([
|
||||
dict(
|
||||
|
@ -91,7 +91,6 @@ COMMON_PIP_DEPS = [
|
||||
"//tensorflow/python/kernel_tests/random:util",
|
||||
"//tensorflow/python/kernel_tests/signal:test_util",
|
||||
"//tensorflow/python/kernel_tests/testdata:self_adjoint_eig_op_test_files",
|
||||
"//tensorflow/python/ops/ragged:ragged_test_util",
|
||||
"//tensorflow/python/saved_model:saved_model",
|
||||
"//tensorflow/python/tools:tools_pip",
|
||||
"//tensorflow/python/tools/api/generator:create_python_api",
|
||||
|
Loading…
Reference in New Issue
Block a user