Created
July 8, 2024 17:44
-
-
Save DarthSim/216551dfd58e5628290e90c1d358704b to your computer and use it in GitHub Desktop.
Revisions
-
DarthSim created this gist
Jul 8, 2024 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,61 @@ diff --git a/export_opencv.py b/export_opencv.py new file mode 100644 index 00000000..15bfef90 --- /dev/null +++ b/export_opencv.py @@ -0,0 +1,23 @@ +from ultralytics import YOLOv10 +import argparse + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("--weights", type=str, + default="yolov10n.pt", + help="model.pt path") + parser.add_argument("--imgsz", type=int, nargs=2, + default=(640, 640), + help="Image size for the model") + parser.add_argument("--half", + action="store_true", + help="FP16 half-precision export") + args = parser.parse_args() + + model = YOLOv10(args.weights) + + model.export(format='onnx', + imgsz=args.imgsz, + simplify=True, + half=args.half) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index a9c5d9ee..544ab8b7 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -79,7 +79,7 @@ class Detect(nn.Module): def forward(self, x): """Concatenates and returns predicted bounding boxes and class probabilities.""" y = self.forward_feat(x, self.cv2, self.cv3) - + if self.training: return y @@ -507,7 +507,7 @@ class v10Detect(Detect): self.one2one_cv2 = copy.deepcopy(self.cv2) self.one2one_cv3 = copy.deepcopy(self.cv3) - + def forward(self, x): one2one = self.forward_feat([xi.detach() for xi in x], self.one2one_cv2, self.one2one_cv3) if not self.export: @@ -519,8 +519,7 @@ class v10Detect(Detect): return {"one2many": one2many, "one2one": one2one} else: assert(self.max_det != -1) - boxes, scores, labels = ops.v10postprocess(one2one.permute(0, 2, 1), self.max_det, self.nc) - return torch.cat([boxes, scores.unsqueeze(-1), labels.unsqueeze(-1).to(boxes.dtype)], dim=-1) + return one2one else: return {"one2many": one2many, "one2one": one2one}