Merged commit includes the following changes:
237836124 by A. Unique TensorFlower<gardener@tensorflow.org>: Internal changes -- 237831779 by A. Unique TensorFlower<gardener@tensorflow.org>: Fix rnn.reset_states() to throw proper error when before .build() is invoked. See https://github.com/tensorflow/tensorflow/issues/25852 for more details. -- 237814078 by A. Unique TensorFlower<gardener@tensorflow.org>: Fix GatherV2 converter. It assumed that the IGatherLayer has same output dimensions as tf.gather, which is not the case. -- 237814023 by A. Unique TensorFlower<gardener@tensorflow.org>: Split py_binary into py_binary and py_library to avoid having py_binary in deps. -- PiperOrigin-RevId: 237836124
This commit is contained in:
parent
2b9900638a
commit
6108cbd1db
@ -2243,8 +2243,8 @@ Status ConvertSqueeze(OpConverterParams* params) {
|
||||
// Make sure target dimension is size 1.
|
||||
if (input_dims[axis] != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Cannot squeeze a dimension which isn't size 1, at ",
|
||||
node_def.name());
|
||||
"Cannot squeeze ", axis, "th dimension ", input_dims[axis],
|
||||
" which isn't size 1, at ", node_def.name());
|
||||
}
|
||||
// Mark dim for removal by setting to 0.
|
||||
input_dims[axis] = 0;
|
||||
@ -3700,13 +3700,59 @@ Status ConvertGather(OpConverterParams* params) {
|
||||
int trt_axis = 0;
|
||||
TF_RETURN_IF_ERROR(ConvertAxis(axis[0], inputs.at(0).GetTrtDims().nbDims,
|
||||
node_def.name(), &trt_axis));
|
||||
TRT_TensorOrWeights params_tensor = inputs.at(0);
|
||||
TRT_TensorOrWeights indices_tensor = inputs.at(1);
|
||||
if (indices_tensor.batch_size() != 1) {
|
||||
return errors::InvalidArgument("Only indices with batch 1 are supported.");
|
||||
}
|
||||
// Both input are tensors, and the TF gather result will have rank:
|
||||
// (params.nbDims + 1) + (indices.nbDims + 1) - 1,
|
||||
// where "+ 1" adds the batch dim.
|
||||
const int tf_gather_output_rank = params_tensor.GetTrtDims().nbDims +
|
||||
indices_tensor.GetTrtDims().nbDims + 1;
|
||||
if (tf_gather_output_rank > nvinfer1::Dims::MAX_DIMS + 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Result of gather has dimension greater than ",
|
||||
nvinfer1::Dims::MAX_DIMS + 1);
|
||||
}
|
||||
if (params->validation_only) return Status::OK();
|
||||
|
||||
// Note on how IGatherLayer works: if both the data and indices tensors have
|
||||
// a batch size dimension of size N, it performs:
|
||||
// for batchid in xrange(N):
|
||||
// output[batchid, a0, ..., an, i, ..., j, b0, ..., bn] = (
|
||||
// data[batchid, a0, ..., an, indices[batchid, i, ..., j] b0, ..., bn])
|
||||
nvinfer1::IGatherLayer* layer = params->converter->network()->addGather(
|
||||
*const_cast<nvinfer1::ITensor*>(inputs.at(0).tensor()),
|
||||
*const_cast<nvinfer1::ITensor*>(inputs.at(1).tensor()), trt_axis);
|
||||
*const_cast<nvinfer1::ITensor*>(params_tensor.tensor()),
|
||||
*const_cast<nvinfer1::ITensor*>(indices_tensor.tensor()), trt_axis);
|
||||
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
|
||||
params->outputs->push_back(TRT_TensorOrWeights(layer->getOutput(0)));
|
||||
|
||||
nvinfer1::ITensor* gather_output = layer->getOutput(0);
|
||||
nvinfer1::Dims trt_gather_output_dims = gather_output->getDimensions();
|
||||
// Note for the "- 2": one is for the output batch dim encapsulated by TF-TRT,
|
||||
// and the other is for the output dimension that is squeezed by IGatherLayer
|
||||
// because of the implicit batch dim in the indices (see the above note).
|
||||
if (trt_gather_output_dims.nbDims != tf_gather_output_rank - 2) {
|
||||
return errors::Internal(
|
||||
"Get unexpected output dimensions of IGatherLayer. Expect nbDims: ",
|
||||
tf_gather_output_rank - 2,
|
||||
", actual nbDims: ", trt_gather_output_dims.nbDims);
|
||||
}
|
||||
// Reshape the output so after adding the implicit batch dim it'll match the
|
||||
// output shape of TF GatherV2.
|
||||
for (int i = trt_gather_output_dims.nbDims; i > trt_axis; --i) {
|
||||
trt_gather_output_dims.d[i] = trt_gather_output_dims.d[i - 1];
|
||||
}
|
||||
trt_gather_output_dims.d[trt_axis] = 1;
|
||||
++trt_gather_output_dims.nbDims;
|
||||
|
||||
const nvinfer1::ITensor* output_tensor = nullptr;
|
||||
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
|
||||
TRT_TensorOrWeights(gather_output), trt_gather_output_dims,
|
||||
/*validation_only=*/false, &output_tensor));
|
||||
|
||||
params->outputs->push_back(
|
||||
TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(output_tensor)));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -3400,14 +3400,23 @@ void TestConvertGather(OpConverterTest* test) {
|
||||
|
||||
// Input is the same {1, 2, 3, 4, 5, 6} for all cases.
|
||||
const int kGatherOKCases = 5;
|
||||
const std::vector<CType> params_input = {CType(1), CType(2), CType(3),
|
||||
CType(4), CType(5), CType(6)};
|
||||
TestParams ok_params[kGatherOKCases] = {
|
||||
// Vector indices (output is rank(params)).
|
||||
TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1}, {1, 4}},
|
||||
TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1}, {2, 5}},
|
||||
TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1}, {3, 6}},
|
||||
TestParams{{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 3}, {3, 1, 2, 6, 4, 5}},
|
||||
// Higher rank indices (output is rank(params) + rank(indices) - 1).
|
||||
TestParams{{1, 2, 3}, {1, 1}, {0}, 2, {1, 1, 1, 3}, {1, 2, 3}},
|
||||
// Indices are always of rank>1, and output rank is
|
||||
// rank(params) + rank(indices) - 1.
|
||||
// TODO(laigd): do we support 0-rank ITensor as indices?
|
||||
TestParams{{1, 2, 3}, {1}, {0}, 3, {1, 2, 1, 1}, {1, 4}},
|
||||
TestParams{{1, 2, 3}, {1}, {1}, 3, {1, 2, 1, 1}, {2, 5}},
|
||||
TestParams{{1, 2, 3}, {1}, {2}, -1, {1, 2, 1, 1}, {3, 6}},
|
||||
TestParams{
|
||||
{1, 2, 3}, {3}, {2, 0, 1}, 3, {1, 2, 1, 3}, {3, 1, 2, 6, 4, 5}},
|
||||
TestParams{{3, 2},
|
||||
{2, 2},
|
||||
{0, 0, 1, 0},
|
||||
2,
|
||||
{3, 1, 2, 2},
|
||||
{1, 1, 2, 1, 3, 3, 4, 3, 5, 5, 6, 5}},
|
||||
};
|
||||
|
||||
// Ok.
|
||||
@ -3426,14 +3435,12 @@ void TestConvertGather(OpConverterTest* test) {
|
||||
output.tensor()->getDimensions());
|
||||
|
||||
// Create input in CType and convert expected output to CType.
|
||||
std::vector<CType> inputs = {CType(1), CType(2), CType(3),
|
||||
CType(4), CType(5), CType(6)};
|
||||
std::vector<CType> converted_expected_output(
|
||||
ok_params[i].expected_output.begin(),
|
||||
ok_params[i].expected_output.end());
|
||||
|
||||
const DataVec input_data{
|
||||
{"params", test::AsTensor<CType>(inputs)},
|
||||
{"params", test::AsTensor<CType>(params_input)},
|
||||
{"indices", test::AsTensor<int32>(ok_params[i].indices)}};
|
||||
DataVec output_data{
|
||||
{"my_gather",
|
||||
|
@ -72,7 +72,7 @@ py_test(
|
||||
"//tensorflow/lite/python:lite",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python/tools:optimize_for_inference",
|
||||
"//tensorflow/python/tools:optimize_for_inference_main_lib",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
@ -95,7 +95,7 @@ py_test(
|
||||
"//tensorflow/lite/python:lite",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python/tools:optimize_for_inference",
|
||||
"//tensorflow/python/tools:optimize_for_inference_main_lib",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
@ -118,7 +118,7 @@ py_test(
|
||||
"//tensorflow/lite/python:lite",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python/tools:optimize_for_inference",
|
||||
"//tensorflow/python/tools:optimize_for_inference_main_lib",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
|
@ -1,245 +1,245 @@
|
||||
tensorflow/contrib/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/contrib/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/README
|
||||
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
|
||||
tensorflow/contrib/mpi/BUILD
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/arm/arm_compiler_configure.bzl
|
||||
tensorflow/third_party/toolchains/cpus/arm/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/cpus/arm/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py3/BUILD
|
||||
tensorflow/third_party/toolchains/cpus/py/BUILD
|
||||
tensorflow/third_party/toolchains/remote/configure.bzl
|
||||
tensorflow/third_party/toolchains/remote/BUILD.tpl
|
||||
tensorflow/third_party/toolchains/remote/BUILD
|
||||
tensorflow/third_party/toolchains/remote/execution.bzl.tpl
|
||||
tensorflow/third_party/toolchains/BUILD
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
|
||||
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu16.04/clang/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/bazel_018/dummy_toolchain.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/win_1803/py36/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/generate/containers.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/workspace.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/archives.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/generate.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/generate/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/py3/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda9.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc7-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc-cuda10.0/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/cuda10.0-cudnn7/cuda/BUILD
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/build_defs.bzl
|
||||
tensorflow/third_party/toolchains/preconfig/ubuntu14.04/tensorrt5/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/repo.bzl
|
||||
tensorflow/third_party/toolchains/clang6/CROSSTOOL.tpl
|
||||
tensorflow/third_party/toolchains/clang6/BUILD
|
||||
tensorflow/third_party/toolchains/clang6/clang.BUILD
|
||||
tensorflow/third_party/toolchains/clang6/README.md
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/gpus/crosstool/LICENSE
|
||||
tensorflow/third_party/gpus/crosstool/windows/msvc_wrapper_for_nvcc.py.tpl
|
||||
tensorflow/third_party/gpus/crosstool/CROSSTOOL_hipcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
|
||||
tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
|
||||
tensorflow/third_party/gpus/crosstool/CROSSTOOL.tpl
|
||||
tensorflow/third_party/gpus/crosstool/BUILD.tpl
|
||||
tensorflow/third_party/gpus/crosstool/BUILD
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda/LICENSE
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.windows.tpl
|
||||
tensorflow/third_party/gpus/cuda/cuda_config.h.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD.tpl
|
||||
tensorflow/third_party/gpus/cuda/BUILD
|
||||
tensorflow/third_party/gpus/cuda/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/rocm/rocm_config.h.tpl
|
||||
tensorflow/third_party/gpus/rocm/BUILD
|
||||
tensorflow/third_party/gpus/rocm/BUILD.tpl
|
||||
tensorflow/third_party/gpus/rocm/build_defs.bzl.tpl
|
||||
tensorflow/third_party/gpus/cuda_configure.bzl
|
||||
tensorflow/third_party/gpus/rocm_configure.bzl
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/farmhash.BUILD
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
|
||||
tensorflow/third_party/systemlibs/absl_py.BUILD
|
||||
tensorflow/third_party/systemlibs/curl.BUILD
|
||||
tensorflow/third_party/systemlibs/termcolor.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.flags.BUILD
|
||||
tensorflow/third_party/systemlibs/grpc.BUILD
|
||||
tensorflow/third_party/systemlibs/swig.BUILD
|
||||
tensorflow/third_party/systemlibs/protobuf.bzl
|
||||
tensorflow/third_party/systemlibs/protobuf.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.BUILD
|
||||
tensorflow/third_party/systemlibs/astor.BUILD
|
||||
tensorflow/third_party/systemlibs/six.BUILD
|
||||
tensorflow/third_party/systemlibs/absl_py.absl.testing.BUILD
|
||||
tensorflow/third_party/systemlibs/boringssl.BUILD
|
||||
tensorflow/third_party/systemlibs/nsync.BUILD
|
||||
tensorflow/third_party/systemlibs/google_cloud_cpp.google.cloud.bigtable.BUILD
|
||||
tensorflow/third_party/systemlibs/gif.BUILD
|
||||
tensorflow/third_party/systemlibs/pcre.BUILD
|
||||
tensorflow/third_party/systemlibs/BUILD.tpl
|
||||
tensorflow/third_party/systemlibs/snappy.BUILD
|
||||
tensorflow/third_party/systemlibs/gast.BUILD
|
||||
tensorflow/third_party/systemlibs/cython.BUILD
|
||||
tensorflow/third_party/systemlibs/double_conversion.BUILD
|
||||
tensorflow/third_party/systemlibs/zlib.BUILD
|
||||
tensorflow/third_party/systemlibs/jsoncpp.BUILD
|
||||
tensorflow/third_party/systemlibs/re2.BUILD
|
||||
tensorflow/third_party/systemlibs/lmdb.BUILD
|
||||
tensorflow/third_party/systemlibs/googleapis.BUILD
|
||||
tensorflow/third_party/systemlibs/png.BUILD
|
||||
tensorflow/third_party/systemlibs/syslibs_configure.bzl
|
||||
tensorflow/third_party/systemlibs/sqlite.BUILD
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/kafka/BUILD
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/gpus/BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/mkl_dnn/LICENSE
|
||||
tensorflow/third_party/mkl_dnn/mkldnn.BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/linenoise.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/common.bzl
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/pprof.BUILD
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_quant.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/git/BUILD.tpl
|
||||
tensorflow/third_party/git/BUILD
|
||||
tensorflow/third_party/git/git_configure.bzl
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/fft2d/LICENSE
|
||||
tensorflow/third_party/fft2d/fft2d.BUILD
|
||||
tensorflow/third_party/fft2d/fft.h
|
||||
tensorflow/third_party/fft2d/BUILD
|
||||
tensorflow/third_party/ngraph/LICENSE
|
||||
tensorflow/third_party/ngraph/build_defs.bzl
|
||||
tensorflow/third_party/ngraph/tbb.BUILD
|
||||
tensorflow/third_party/ngraph/ngraph.BUILD
|
||||
tensorflow/third_party/ngraph/nlohmann_json.BUILD
|
||||
tensorflow/third_party/ngraph/BUILD
|
||||
tensorflow/third_party/ngraph/ngraph_tf.BUILD
|
||||
tensorflow/third_party/ngraph/NGRAPH_LICENSE
|
||||
tensorflow/third_party/grpc/BUILD
|
||||
tensorflow/third_party/curl.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/cython.BUILD
|
||||
tensorflow/third_party/icu/udata.patch
|
||||
tensorflow/third_party/astor.BUILD
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
tensorflow/third_party/sycl/crosstool/BUILD
|
||||
tensorflow/third_party/llvm/llvm.autogenerated.BUILD
|
||||
tensorflow/third_party/llvm/expand_cmake_vars.py
|
||||
tensorflow/third_party/llvm/llvm.bzl
|
||||
tensorflow/third_party/llvm/BUILD
|
||||
tensorflow/third_party/png.BUILD
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/arm_neon_2_x86_sse.BUILD
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/enum34.BUILD
|
||||
tensorflow/third_party/kafka/config.patch
|
||||
tensorflow/third_party/kafka/BUILD
|
||||
tensorflow/third_party/pcre.BUILD
|
||||
tensorflow/third_party/mpi/BUILD
|
||||
tensorflow/third_party/mpi/.gitignore
|
||||
tensorflow/third_party/clang_toolchain/BUILD
|
||||
tensorflow/third_party/clang_toolchain/download_clang.bzl
|
||||
tensorflow/third_party/clang_toolchain/cc_configure_clang.bzl
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/png_fix_rpi.patch
|
||||
tensorflow/third_party/py/python_configure.bzl
|
||||
tensorflow/third_party/py/BUILD.tpl
|
||||
tensorflow/third_party/py/BUILD
|
||||
tensorflow/third_party/py/numpy/BUILD
|
||||
tensorflow/third_party/double_conversion.BUILD
|
||||
tensorflow/third_party/six.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/repo.bzl
|
||||
tensorflow/third_party/codegen.BUILD
|
||||
tensorflow/third_party/cub.BUILD
|
||||
tensorflow/third_party/jsoncpp.BUILD
|
||||
tensorflow/third_party/tflite_ovic_testdata.BUILD
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/zlib.BUILD
|
||||
tensorflow/third_party/lmdb.BUILD
|
||||
tensorflow/third_party/nanopb.BUILD
|
||||
tensorflow/third_party/android/android.bzl.tpl
|
||||
tensorflow/third_party/android/BUILD
|
||||
tensorflow/third_party/android/android_configure.BUILD.tpl
|
||||
tensorflow/third_party/android/android_configure.bzl
|
||||
tensorflow/third_party/tflite_mobilenet_float.BUILD
|
||||
tensorflow/third_party/sqlite.BUILD
|
||||
tensorflow/third_party/tensorrt/build_defs.bzl.tpl
|
||||
tensorflow/third_party/tensorrt/LICENSE
|
||||
tensorflow/third_party/tensorrt/tensorrt_configure.bzl
|
||||
tensorflow/third_party/tensorrt/BUILD.tpl
|
||||
tensorflow/third_party/tensorrt/BUILD
|
||||
tensorflow/third_party/gast.BUILD
|
||||
tensorflow/third_party/mpi_collectives/BUILD
|
||||
tensorflow/third_party/libxsmm.BUILD
|
||||
tensorflow/third_party/eigen.BUILD
|
||||
tensorflow/third_party/com_google_absl.BUILD
|
||||
tensorflow/third_party/eigen3/LICENSE
|
||||
tensorflow/third_party/eigen3/gpu_packet_math.patch
|
||||
tensorflow/third_party/eigen3/BUILD
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX512.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/TypeCastingAVX2.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatMatProductNEON.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/FixedPointTypes.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/MatVecProduct.h
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/FixedPoint
|
||||
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
|
||||
tensorflow/third_party/eigen3/Eigen/QR
|
||||
tensorflow/third_party/eigen3/Eigen/SVD
|
||||
tensorflow/third_party/eigen3/Eigen/LU
|
||||
tensorflow/third_party/eigen3/Eigen/Cholesky
|
||||
tensorflow/third_party/eigen3/Eigen/Eigenvalues
|
||||
tensorflow/third_party/eigen3/Eigen/Core
|
||||
tensorflow/third_party/BUILD
|
||||
tensorflow/third_party/termcolor.BUILD
|
||||
tensorflow/third_party/gif.BUILD
|
||||
tensorflow/third_party/tflite_mobilenet.BUILD
|
||||
tensorflow/third_party/mkl/LICENSE
|
||||
tensorflow/third_party/mkl/build_defs.bzl
|
||||
tensorflow/third_party/mkl/mkl.BUILD
|
||||
tensorflow/third_party/mkl/MKL_LICENSE
|
||||
tensorflow/third_party/mkl/BUILD
|
||||
tensorflow/third_party/nccl/build_defs.bzl.tpl
|
||||
tensorflow/third_party/nccl/LICENSE
|
||||
tensorflow/third_party/nccl/nccl_configure.bzl
|
||||
tensorflow/third_party/nccl/archive.BUILD
|
||||
tensorflow/third_party/nccl/BUILD
|
||||
tensorflow/third_party/nccl/system.BUILD.tpl
|
||||
tensorflow/third_party/snappy.BUILD
|
||||
tensorflow/third_party/python_runtime/BUILD
|
||||
tensorflow/third_party/googleapis.BUILD
|
||||
tensorflow/third_party/boringssl/BUILD
|
||||
tensorflow/third_party/protobuf/BUILD
|
||||
tensorflow/third_party/backports_weakref.BUILD
|
||||
tensorflow/third_party/tflite_smartreply.BUILD
|
||||
tensorflow/third_party/swig.BUILD
|
||||
tensorflow/compat_template.__init__.py
|
||||
tensorflow/tools/lib_package/libtensorflow_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_java_test.sh
|
||||
tensorflow/tools/lib_package/libtensorflow_test.c
|
||||
tensorflow/tools/lib_package/concat_licenses.sh
|
||||
tensorflow/tools/lib_package/LibTensorFlowTest.java
|
||||
tensorflow/tools/lib_package/BUILD
|
||||
tensorflow/tools/lib_package/README.md
|
||||
tensorflow/tools/pip_package/check_load_py_test.py
|
||||
tensorflow/tools/pip_package/simple_console.py
|
||||
tensorflow/tools/pip_package/pip_smoke_test.py
|
||||
tensorflow/tools/pip_package/BUILD
|
||||
tensorflow/tools/pip_package/simple_console_for_windows.py
|
||||
tensorflow/tools/pip_package/build_pip_package.sh
|
||||
tensorflow/tools/pip_package/README
|
||||
tensorflow/tools/pip_package/setup.py
|
||||
tensorflow/tools/pip_package/MANIFEST.in
|
||||
tensorflow/tools/ci_build/remote/BUILD
|
||||
tensorflow/tools/def_file_filter/def_file_filter.py.tpl
|
||||
tensorflow/tools/def_file_filter/BUILD.tpl
|
||||
tensorflow/tools/def_file_filter/BUILD
|
||||
tensorflow/tools/def_file_filter/def_file_filter_configure.bzl
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/BUILD
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/__init__.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py
|
||||
tensorflow/contrib/tpu/profiler/pip_package/build_pip_package.sh
|
||||
tensorflow/contrib/tpu/profiler/pip_package/README
|
||||
tensorflow/contrib/tpu/profiler/pip_package/setup.py
|
||||
tensorflow/contrib/mpi/BUILD
|
||||
tensorflow/__init__.py
|
||||
tensorflow/stream_executor/build_defs.bzl
|
||||
tensorflow/api_template_v1.__init__.py
|
||||
tensorflow/compat_template_v1.__init__.py
|
||||
tensorflow/compat_template.__init__.py
|
||||
tensorflow/api_template.__init__.py
|
||||
tensorflow/__init__.py
|
||||
tensorflow/compat_template_v1.__init__.py
|
@ -47,7 +47,7 @@ py_library(
|
||||
":cli_test_utils",
|
||||
":debug_py",
|
||||
":grpc_debug_test_server",
|
||||
":grpc_tensorflow_server",
|
||||
":grpc_tensorflow_server_lib",
|
||||
":offline_analyzer_lib",
|
||||
":session_debug_testlib",
|
||||
":source_remote",
|
||||
@ -1124,7 +1124,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
data = [":grpc_tensorflow_server"],
|
||||
data = [":grpc_tensorflow_server_lib"],
|
||||
grpc_enabled = True,
|
||||
tags = [
|
||||
"no_oss", # Incompatible with bazel_pip.
|
||||
|
@ -833,10 +833,14 @@ class RNN(Layer):
|
||||
def reset_states(self, states=None):
|
||||
if not self.stateful:
|
||||
raise AttributeError('Layer must be stateful.')
|
||||
if self.time_major:
|
||||
batch_size = self.input_spec[0].shape[1]
|
||||
spec_shape = None if self.input_spec is None else self.input_spec[0].shape
|
||||
if spec_shape is None:
|
||||
# It is possible to have spec shape to be None, eg when construct a RNN
|
||||
# with a custom cell, or standard RNN layers (LSTM/GRU) which we only know
|
||||
# it has 3 dim input, but not its full shape spec before build().
|
||||
batch_size = None
|
||||
else:
|
||||
batch_size = self.input_spec[0].shape[0]
|
||||
batch_size = spec_shape[1] if self.time_major else spec_shape[0]
|
||||
if not batch_size:
|
||||
raise ValueError('If a RNN is stateful, it needs to know '
|
||||
'its batch size. Specify the batch size '
|
||||
|
@ -1315,6 +1315,17 @@ class RNNTest(keras_parameterized.TestCase):
|
||||
model = keras.Model([inputs, state_h, state_c], decoder_out)
|
||||
model.reset_states()
|
||||
|
||||
def test_reset_states(self):
|
||||
# See https://github.com/tensorflow/tensorflow/issues/25852
|
||||
with self.assertRaisesRegexp(ValueError, 'it needs to know its batch size'):
|
||||
simple_rnn = keras.layers.SimpleRNN(1, stateful=True)
|
||||
simple_rnn.reset_states()
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'it needs to know its batch size'):
|
||||
cell = Minimal2DRNNCell(1, 2)
|
||||
custom_rnn = keras.layers.RNN(cell, stateful=True)
|
||||
custom_rnn.reset_states()
|
||||
|
||||
|
||||
class Minimal2DRNNCell(keras.layers.Layer):
|
||||
"""The minimal 2D RNN cell is a simple combination of 2 1-D RNN cell.
|
||||
|
Loading…
x
Reference in New Issue
Block a user