Add File
This commit is contained in:
38
pcdet/utils/spconv_utils.py
Normal file
38
pcdet/utils/spconv_utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Set
|
||||
|
||||
import spconv
|
||||
if float(spconv.__version__[2:]) >= 2.2:
|
||||
spconv.constants.SPCONV_USE_DIRECT_TABLE = False
|
||||
|
||||
try:
|
||||
import spconv.pytorch as spconv
|
||||
except:
|
||||
import spconv as spconv
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
|
||||
"""
|
||||
Finds all spconv keys that need to have weight's transposed
|
||||
"""
|
||||
found_keys: Set[str] = set()
|
||||
for name, child in model.named_children():
|
||||
new_prefix = f"{prefix}.{name}" if prefix != "" else name
|
||||
|
||||
if isinstance(child, spconv.conv.SparseConvolution):
|
||||
new_prefix = f"{new_prefix}.weight"
|
||||
found_keys.add(new_prefix)
|
||||
|
||||
found_keys.update(find_all_spconv_keys(child, prefix=new_prefix))
|
||||
|
||||
return found_keys
|
||||
|
||||
|
||||
def replace_feature(out, new_features):
|
||||
if "replace_feature" in out.__dir__():
|
||||
# spconv 2.x behaviour
|
||||
return out.replace_feature(new_features)
|
||||
else:
|
||||
out.features = new_features
|
||||
return out
|
||||
Reference in New Issue
Block a user