This commit is contained in:
2025-09-21 20:19:08 +08:00
parent 42f93d879d
commit 79335ce388

View File

@@ -0,0 +1,79 @@
import torch
class AnchorGenerator(object):
def __init__(self, anchor_range, anchor_generator_config):
super().__init__()
self.anchor_generator_cfg = anchor_generator_config
self.anchor_range = anchor_range
self.anchor_sizes = [config['anchor_sizes'] for config in anchor_generator_config]
self.anchor_rotations = [config['anchor_rotations'] for config in anchor_generator_config]
self.anchor_heights = [config['anchor_bottom_heights'] for config in anchor_generator_config]
self.align_center = [config.get('align_center', False) for config in anchor_generator_config]
assert len(self.anchor_sizes) == len(self.anchor_rotations) == len(self.anchor_heights)
self.num_of_anchor_sets = len(self.anchor_sizes)
def generate_anchors(self, grid_sizes):
assert len(grid_sizes) == self.num_of_anchor_sets
all_anchors = []
num_anchors_per_location = []
for grid_size, anchor_size, anchor_rotation, anchor_height, align_center in zip(
grid_sizes, self.anchor_sizes, self.anchor_rotations, self.anchor_heights, self.align_center):
num_anchors_per_location.append(len(anchor_rotation) * len(anchor_size) * len(anchor_height))
if align_center:
x_stride = (self.anchor_range[3] - self.anchor_range[0]) / grid_size[0]
y_stride = (self.anchor_range[4] - self.anchor_range[1]) / grid_size[1]
x_offset, y_offset = x_stride / 2, y_stride / 2
else:
x_stride = (self.anchor_range[3] - self.anchor_range[0]) / (grid_size[0] - 1)
y_stride = (self.anchor_range[4] - self.anchor_range[1]) / (grid_size[1] - 1)
x_offset, y_offset = 0, 0
x_shifts = torch.arange(
self.anchor_range[0] + x_offset, self.anchor_range[3] + 1e-5, step=x_stride, dtype=torch.float32,
).cuda()
y_shifts = torch.arange(
self.anchor_range[1] + y_offset, self.anchor_range[4] + 1e-5, step=y_stride, dtype=torch.float32,
).cuda()
z_shifts = x_shifts.new_tensor(anchor_height)
num_anchor_size, num_anchor_rotation = anchor_size.__len__(), anchor_rotation.__len__()
anchor_rotation = x_shifts.new_tensor(anchor_rotation)
anchor_size = x_shifts.new_tensor(anchor_size)
x_shifts, y_shifts, z_shifts = torch.meshgrid([
x_shifts, y_shifts, z_shifts
]) # [x_grid, y_grid, z_grid]
anchors = torch.stack((x_shifts, y_shifts, z_shifts), dim=-1) # [x, y, z, 3]
anchors = anchors[:, :, :, None, :].repeat(1, 1, 1, anchor_size.shape[0], 1)
anchor_size = anchor_size.view(1, 1, 1, -1, 3).repeat([*anchors.shape[0:3], 1, 1])
anchors = torch.cat((anchors, anchor_size), dim=-1)
anchors = anchors[:, :, :, :, None, :].repeat(1, 1, 1, 1, num_anchor_rotation, 1)
anchor_rotation = anchor_rotation.view(1, 1, 1, 1, -1, 1).repeat([*anchors.shape[0:3], num_anchor_size, 1, 1])
anchors = torch.cat((anchors, anchor_rotation), dim=-1) # [x, y, z, num_size, num_rot, 7]
anchors = anchors.permute(2, 1, 0, 3, 4, 5).contiguous()
#anchors = anchors.view(-1, anchors.shape[-1])
anchors[..., 2] += anchors[..., 5] / 2 # shift to box centers
all_anchors.append(anchors)
return all_anchors, num_anchors_per_location
if __name__ == '__main__':
from easydict import EasyDict
config = [
EasyDict({
'anchor_sizes': [[2.1, 4.7, 1.7], [0.86, 0.91, 1.73], [0.84, 1.78, 1.78]],
'anchor_rotations': [0, 1.57],
'anchor_heights': [0, 0.5]
})
]
A = AnchorGenerator(
anchor_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
anchor_generator_config=config
)
import pdb
pdb.set_trace()
A.generate_anchors([[188, 188]])