Skip to content

Commit

Permalink
feat: Run Detection Models from Terminal
Browse files Browse the repository at this point in the history
* Kept arguments mostly inline with the MegaDetector run_batch_detector.py script.
* Currently missing recursive, image queue and checkpoint functionality in comparison.

Once installed, allows a user to run the command with:

    python -m camtrapml.scripts.batch_detection model_name path_to_image_dir path_to_output_json
  • Loading branch information
bencevans committed Jun 23, 2022
1 parent b5a86f2 commit 5c8cb38
Showing 1 changed file with 112 additions and 0 deletions.
112 changes: 112 additions & 0 deletions camtrapml/scripts/batch_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
Command Line Utility for batch detection.
"""

from argparse import ArgumentParser
from json import dump
from datetime import datetime
from tqdm import tqdm
from camtrapml.dataset import ImageDataset
from camtrapml.detection.models.megadetector import (
MegaDetectorV2,
MegaDetectorV3,
MegaDetectorV4_1,
)


def parse_args():
"""
Parse command line arguments.
"""

parser = ArgumentParser()
parser.add_argument(
"model",
type=str,
help="Detection model to utilise [md4, md5a, md5b]",
)
parser.add_argument(
"dataset_path",
type=str,
help="Path to directory containing the images",
)
parser.add_argument("output_path", type=str, help="Path to store the JSON output")

parser.add_argument("--output_relative_filenames", action="store_true")

return parser.parse_args()


def get_model(model_name):
"""
Load a detection model based on a short name.
"""

if model_name == "md2":
return MegaDetectorV2()

if model_name == "md3":
return MegaDetectorV3()

if model_name == "md4":
return MegaDetectorV4_1()

raise ValueError(f"Unknown model {model_name}")


def detection_to_json_types(detection):
"""
Convert a detection to a JSON-compatible dictionary.
"""
detection["conf"] = float(detection["conf"])
detection["bbox"] = [float(x) for x in detection["bbox"]]
detection["category"] = str(detection["category"])
return detection


def batch_detection():
"""
Run detection on a batch of images.
"""

args = parse_args()

model = get_model(args.model)

print("Enumerating images...", end="")
dataset = ImageDataset(args.dataset_path)
image_paths = list(tqdm(dataset.enumerate_images()))
print(" Done")

results = {"images": []}

for image_path in tqdm(image_paths):
if args.output_relative_filenames:
output_image_path = str(image_path.relative_to(args.dataset_path))
else:
output_image_path = str(image_path)

prediction = model.detect(image_path)

results["images"].append(
{
"file": output_image_path,
"detections": [
detection_to_json_types(detection) for detection in prediction
],
}
)

results["detection_categories"] = {"1": "animal", "2": "person", "3": "vehicle"}

results["info"] = {
"detection_completion_time": str(datetime.now()),
"format_version": "1.0",
}

with open(args.output_path, "w") as file_handle:
dump(results, file_handle, indent=2)


if __name__ == "__main__":
batch_detection()

0 comments on commit 5c8cb38

Please sign in to comment.