From 66be37a3b9d98a7b0f3312346ced2c23d46a89b6 Mon Sep 17 00:00:00 2001 From: inter Date: Sun, 21 Sep 2025 20:19:27 +0800 Subject: [PATCH] Add File --- pcdet/utils/spconv_utils.py | 38 +++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 pcdet/utils/spconv_utils.py diff --git a/pcdet/utils/spconv_utils.py b/pcdet/utils/spconv_utils.py new file mode 100644 index 0000000..c38f899 --- /dev/null +++ b/pcdet/utils/spconv_utils.py @@ -0,0 +1,38 @@ +from typing import Set + +import spconv +if float(spconv.__version__[2:]) >= 2.2: + spconv.constants.SPCONV_USE_DIRECT_TABLE = False + +try: + import spconv.pytorch as spconv +except: + import spconv as spconv + +import torch.nn as nn + + +def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]: + """ + Finds all spconv keys that need to have weight's transposed + """ + found_keys: Set[str] = set() + for name, child in model.named_children(): + new_prefix = f"{prefix}.{name}" if prefix != "" else name + + if isinstance(child, spconv.conv.SparseConvolution): + new_prefix = f"{new_prefix}.weight" + found_keys.add(new_prefix) + + found_keys.update(find_all_spconv_keys(child, prefix=new_prefix)) + + return found_keys + + +def replace_feature(out, new_features): + if "replace_feature" in out.__dir__(): + # spconv 2.x behaviour + return out.replace_feature(new_features) + else: + out.features = new_features + return out