Moving ComputeArgumentShapes to tpu_compile_op_support.
PiperOrigin-RevId: 324064650
Change-Id: I057003860e7849b63351d855b4a9c0a166a10cad
parent c1336e9a40
commit 34bd3aaad4
tensorflow/core/tpu/kernels
@@ -413,46 +413,6 @@ Status TpuCompileOpKernelCommon::CompileTFFunctionToHlo(
   return Status::OK();
 }
 
-/* static */
-Status TpuCompileOpKernelCommon::ComputeArgumentShapes(
-    const tpu::TPUCompileMetadataProto& metadata,
-    const std::vector<TensorShape>& dynamic_shapes,
-    std::vector<TensorShape>* arg_shapes) {
-  arg_shapes->resize(metadata.args_size());
-  int dynamic_shape_pos = 0;
-  for (int i = 0; i < metadata.args_size(); ++i) {
-    const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
-    // The XLA compiler determines the shape of each constant by inspecting the
-    // value of its corresponding host-memory tensor. As a result, we don't need
-    // to give the compiler graph-inferred shapes for constant arguments.
-    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
-      continue;
-    }
-    TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
-    PartialTensorShape static_shape(arg.shape());
-
-    TensorShape& shape = (*arg_shapes)[i];
-    if (static_shape.IsFullyDefined()) {
-      TF_RET_CHECK(static_shape.AsTensorShape(&shape));
-    } else {
-      TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
-          << "Too few dynamic shapes";
-      shape = dynamic_shapes[dynamic_shape_pos++];
-      if (!static_shape.IsCompatibleWith(shape)) {
-        return errors::InvalidArgument(
-            "Mismatch between static and dynamic shape for argument. Static "
-            "shape: ",
-            static_shape.DebugString(),
-            "; dynamic shape: ", shape.DebugString());
-      }
-    }
-  }
-  // Checks we consumed all of the dynamic shapes.
-  TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
-      << "Too many dynamic shapes";
-  return Status::OK();
-}
-
 // Function arguments and return values lose their device assignments, so we
 // must recreate them.
 /* static */ Status TpuCompileOpKernelCommon::AssignDevicesToArgsAndRetvals(
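Not shown in this excerpt, but implied by the move: any call site inside TpuCompileOpKernelCommon that previously invoked the static member now has to call the relocated free function instead. A minimal sketch of such an updated call site, where the variable names (metadata, dynamic_shapes) are illustrative rather than taken from the diff:

  // Hypothetical caller after the move; whether the call needs the tpu::
  // qualification depends on the namespace the caller lives in.
  std::vector<TensorShape> arg_shapes;
  TF_RETURN_IF_ERROR(
      tpu::ComputeArgumentShapes(metadata, dynamic_shapes, &arg_shapes));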
@@ -99,15 +99,6 @@ class TpuCompileOpKernelCommon {
       const std::vector<TensorShape>& arg_shapes,
       TpuProgramGroupInterface* tpu_program_group) = 0;
 
-  // Computes shapes for each argument. Uses both the static shape from the
-  // metadata, and the dynamic shapes where the static shape is not
-  // defined. There must be one dynamic_shape for each argument with a
-  // partially defined shape, in index order.
-  static Status ComputeArgumentShapes(
-      const tpu::TPUCompileMetadataProto& metadata,
-      const std::vector<TensorShape>& dynamic_shapes,
-      std::vector<TensorShape>* arg_shapes);
-
   // Performs shape inference on `computation`, filling shape_info with operator
   // shapes. The shapes of the _Arg nodes are taken from `arg_shapes`.
   static Status RunShapeInferenceOnComputation(
@@ -540,5 +540,43 @@ Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
   }
   return Status::OK();
 }
+
+Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
+                             const std::vector<TensorShape>& dynamic_shapes,
+                             std::vector<TensorShape>* arg_shapes) {
+  arg_shapes->resize(metadata.args_size());
+  int dynamic_shape_pos = 0;
+  for (int i = 0; i < metadata.args_size(); ++i) {
+    const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
+    // The XLA compiler determines the shape of each constant by inspecting the
+    // value of its corresponding host-memory tensor. As a result, we don't need
+    // to give the compiler graph-inferred shapes for constant arguments.
+    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
+      continue;
+    }
+    TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
+    PartialTensorShape static_shape(arg.shape());
+
+    TensorShape& shape = (*arg_shapes)[i];
+    if (static_shape.IsFullyDefined()) {
+      TF_RET_CHECK(static_shape.AsTensorShape(&shape));
+    } else {
+      TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
+          << "Too few dynamic shapes";
+      shape = dynamic_shapes[dynamic_shape_pos++];
+      if (!static_shape.IsCompatibleWith(shape)) {
+        return errors::InvalidArgument(
+            "Mismatch between static and dynamic shape for argument. Static "
+            "shape: ",
+            static_shape.DebugString(),
+            "; dynamic shape: ", shape.DebugString());
+      }
+    }
+  }
+  // Checks we consumed all of the dynamic shapes.
+  TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
+      << "Too many dynamic shapes";
+  return Status::OK();
+}
 }  // namespace tpu
 }  // namespace tensorflow
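To make the contract of the relocated function concrete, here is a hedged usage sketch (not part of the commit): one argument whose shape is fully defined in the metadata, and one partially defined argument that consumes the single entry of dynamic_shapes, in index order. The proto and shape accessors follow the code above; the scaffolding and values are illustrative only and assume the caller is inside a Status-returning function.

  // arg 0: static shape [3, 4]  -> taken directly from the metadata.
  // arg 1: static shape [2, -1] -> partially defined, so it consumes the
  //                                next (and only) entry of dynamic_shapes.
  tpu::TPUCompileMetadataProto metadata;
  TensorShapeProto* s0 = metadata.add_args()->mutable_shape();
  s0->add_dim()->set_size(3);
  s0->add_dim()->set_size(4);
  TensorShapeProto* s1 = metadata.add_args()->mutable_shape();
  s1->add_dim()->set_size(2);
  s1->add_dim()->set_size(-1);  // unknown dimension

  std::vector<TensorShape> dynamic_shapes = {TensorShape({2, 5})};
  std::vector<TensorShape> arg_shapes;
  TF_RETURN_IF_ERROR(
      tpu::ComputeArgumentShapes(metadata, dynamic_shapes, &arg_shapes));
  // arg_shapes[0] == [3, 4]; arg_shapes[1] == [2, 5].
  // A second entry in dynamic_shapes would trip the "Too many dynamic shapes"
  // check, and an entry incompatible with [2, -1] would be InvalidArgument.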
@@ -159,6 +159,14 @@ se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                               TPUCompileMetadataProto* metadata,
                                               NameAttrList* function_name,
                                               std::string* mlir_module);
+
+// Computes shapes for each argument. Uses both the static shape from the
+// metadata, and the dynamic shapes where the static shape is not
+// defined. There must be one dynamic_shape for each argument with a
+// partially defined shape, in index order.
+Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
+                             const std::vector<TensorShape>& dynamic_shapes,
+                             std::vector<TensorShape>* arg_shapes);
 }  // namespace tpu
 }  // namespace tensorflow