diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc index 168d2507e34..8bd45db2206 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.cc @@ -413,46 +413,6 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo( return Status::OK(); } -/* static */ -Status TpuCompileOpKernelCommon::ComputeArgumentShapes( - const tpu::TPUCompileMetadataProto& metadata, - const std::vector<TensorShape>& dynamic_shapes, - std::vector<TensorShape>* arg_shapes) { - arg_shapes->resize(metadata.args_size()); - int dynamic_shape_pos = 0; - for (int i = 0; i < metadata.args_size(); ++i) { - const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i); - // The XLA compiler determines the shape of each constant by inspecting the - // value of its corresponding host-memory tensor. As a result, we don't need - // to give the compiler graph-inferred shapes for constant arguments. - if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) { - continue; - } - TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape())); - PartialTensorShape static_shape(arg.shape()); - - TensorShape& shape = (*arg_shapes)[i]; - if (static_shape.IsFullyDefined()) { - TF_RET_CHECK(static_shape.AsTensorShape(&shape)); - } else { - TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size()) - << "Too few dynamic shapes"; - shape = dynamic_shapes[dynamic_shape_pos++]; - if (!static_shape.IsCompatibleWith(shape)) { - return errors::InvalidArgument( - "Mismatch between static and dynamic shape for argument. Static " - "shape: ", - static_shape.DebugString(), - "; dynamic shape: ", shape.DebugString()); - } - } - } - // Checks we consumed all of the dynamic shapes. - TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size()) - << "Too many dynamic shapes"; - return Status::OK(); -} - // Function arguments and return values lose their device assignments, so we // must recreate them. 
/* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals( diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_common.h b/tensorflow/core/tpu/kernels/tpu_compile_op_common.h index 3d3f0afcdb7..327aa460ddd 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_common.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_common.h @@ -99,15 +99,6 @@ class TpuCompileOpKernelCommon { const std::vector<TensorShape>& arg_shapes, TpuProgramGroupInterface* tpu_program_group) = 0; - // Computes shapes for each argument. Uses both the static shape from the - // metadata, and the dynamic shapes where the static shape is not - // defined. There must be one dynamic_shape for each argument with a - // partially defined shape, in index order. - static Status ComputeArgumentShapes( - const tpu::TPUCompileMetadataProto& metadata, - const std::vector<TensorShape>& dynamic_shapes, - std::vector<TensorShape>* arg_shapes); - // Performs shape inference on `computation`, filling shape_info with operator // shapes. The shapes of the _Arg nodes are taken from `arg_shapes`. static Status RunShapeInferenceOnComputation( diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc index 5cc35a07e66..3440b6d265a 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.cc @@ -540,5 +540,43 @@ Status CompileOpMetadataFromContext(OpKernelConstruction* ctx, } return Status::OK(); } + +Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata, + const std::vector<TensorShape>& dynamic_shapes, + std::vector<TensorShape>* arg_shapes) { + arg_shapes->resize(metadata.args_size()); + int dynamic_shape_pos = 0; + for (int i = 0; i < metadata.args_size(); ++i) { + const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i); + // The XLA compiler determines the shape of each constant by inspecting the + // value of its corresponding host-memory tensor. 
As a result, we don't need + // to give the compiler graph-inferred shapes for constant arguments. + if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) { + continue; + } + TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape())); + PartialTensorShape static_shape(arg.shape()); + + TensorShape& shape = (*arg_shapes)[i]; + if (static_shape.IsFullyDefined()) { + TF_RET_CHECK(static_shape.AsTensorShape(&shape)); + } else { + TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size()) + << "Too few dynamic shapes"; + shape = dynamic_shapes[dynamic_shape_pos++]; + if (!static_shape.IsCompatibleWith(shape)) { + return errors::InvalidArgument( + "Mismatch between static and dynamic shape for argument. Static " + "shape: ", + static_shape.DebugString(), + "; dynamic shape: ", shape.DebugString()); + } + } + } + // Checks we consumed all of the dynamic shapes. + TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size()) + << "Too many dynamic shapes"; + return Status::OK(); +} } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h index bc60f64286a..ea13d33b521 100644 --- a/tensorflow/core/tpu/kernels/tpu_compile_op_support.h +++ b/tensorflow/core/tpu/kernels/tpu_compile_op_support.h @@ -159,6 +159,14 @@ se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx, TPUCompileMetadataProto* metadata, NameAttrList* function_name, std::string* mlir_module); + +// Computes shapes for each argument. Uses both the static shape from the
// metadata, and the dynamic shapes where the static shape is not
// defined. There must be one dynamic_shape for each argument with a
// partially defined shape, in index order.
+Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
+                             const std::vector<TensorShape>& dynamic_shapes,
+                             std::vector<TensorShape>* arg_shapes);