
Where is the ant?

Using YOLOv8 for object detection, we analyzed thousands of hours of video to study how ants interact with wounded trees in Costa Rica. By leveraging AI-driven tracking, we identified and compared ant activity on both healthy and wounded branches, providing valuable insights into potential symbiotic relationships. The combination of classification and bounding box detection allowed us to quantify these interactions with precision, creating a data-driven foundation for further ecological research. With the findings expected to be published later this year, this study paves the way for deeper investigations into how ants respond to environmental changes.
In 2024, a research student from the University of Regensburg approached me for advice and consultancy on object detection. Their study focused on a specific ant species and its interaction with wounded trees in Costa Rica. The goal was to analyze thousands of hours of video footage and run a statistical evaluation to determine whether there is a significant correlation between the presence of ants and the wounded areas. This analysis would serve as a basis for further hypotheses and follow-up studies. Both the absolute number of ants present on wounds per frame and heat maps of ant activity would help support the hypothesis that a form of symbiosis might be occurring.
Since the goal was not only classification but also object detection, a YOLOv8 model was chosen over a VGG16 network. Although YOLO can be slower than a plain VGG classifier, it offers a key advantage: its output directly contains the bounding box coordinates along with the class prediction and confidence score. This eliminates the need for a separate localization step, making the workflow more efficient and easier to understand overall. The code is partly based on the YOLOv8 example from the keras-io repository, adapted for this purpose, and may serve as an example implementation for readers taking their first steps in object detection.
The model was trained on manually labeled images. Labeling was done in Label Studio, which allowed university students and assistants to contribute easily to the project. The annotations were exported as Pascal VOC-style XML files, loaded into memory and converted into a `tf.data.Dataset` to create the data pipeline for the model. This set-up made it possible to keep feeding the model with an ever-growing set of training data and to improve its accuracy over time.
```python
import os
import xml.etree.ElementTree as ET

from tqdm import tqdm

# path_labelled (folder with the labeled images), xml_files (list of exported
# XML files) and class_mapping (dict of class ID -> class name) are defined
# earlier in the full repository.


def parse_annotation(xml_file):
    """Reads an XML annotation file, finds the image name and path, and iterates
    over each object in the file to extract its bounding box and class label.

    The function returns three values: the image path, a list of bounding boxes
    (each a list of four floats: xmin, ymin, xmax, ymax), and a list of integer
    class IDs, one per bounding box. The class IDs are obtained by looking up the
    class labels in the `class_mapping` dictionary.

    Parameters
    ----------
    xml_file (str): the XML file to parse

    Returns
    -------
    image_path (str): the path to the image
    boxes (list): a list of bounding boxes
    class_ids (list): a list of class IDs
    """
    tree = ET.parse(xml_file)
    root = tree.getroot()

    image_name = root.find("filename").text
    image_path = os.path.join(path_labelled, image_name)

    boxes = []
    classes = []
    for obj in root.iter("object"):
        cls = obj.find("name").text
        classes.append(cls)

        bbox = obj.find("bndbox")
        xmin = float(bbox.find("xmin").text)
        ymin = float(bbox.find("ymin").text)
        xmax = float(bbox.find("xmax").text)
        ymax = float(bbox.find("ymax").text)
        boxes.append([xmin, ymin, xmax, ymax])

    # Reverse lookup: class_mapping maps ID -> name, so find the ID of each name
    class_ids = [
        list(class_mapping.keys())[list(class_mapping.values()).index(cls)]
        for cls in classes
    ]
    return image_path, boxes, class_ids


image_paths = []
bbox = []
classes = []

# Parse every exported annotation file
for xml_file in tqdm(xml_files):
    image_path, boxes, class_ids = parse_annotation(xml_file)
    image_paths.append(image_path)
    bbox.append(boxes)
    classes.append(class_ids)
```
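To see what `parse_annotation` extracts, here is a self-contained sketch that applies the same ElementTree lookups to an inline XML string instead of a file. The annotation content, the file name and the `class_mapping` below are made-up examples, not data from the study, and `parse_annotation_string` is a hypothetical helper.

```python
import xml.etree.ElementTree as ET

# Hypothetical Pascal VOC-style annotation with two objects
SAMPLE_XML = """
<annotation>
  <filename>frame_000123.jpg</filename>
  <object>
    <name>ant</name>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>30</xmax><ymax>45</ymax></bndbox>
  </object>
  <object>
    <name>wound</name>
    <bndbox><xmin>5</xmin><ymin>5</ymin><xmax>200</xmax><ymax>150</ymax></bndbox>
  </object>
</annotation>
"""

# Hypothetical mapping of class IDs to class names
class_mapping = {0: "ant", 1: "wound"}


def parse_annotation_string(xml_text, mapping):
    """Same extraction as parse_annotation, but from an XML string."""
    root = ET.fromstring(xml_text.strip())
    image_name = root.find("filename").text
    # Invert the ID -> name mapping once instead of searching lists per object
    name_to_id = {name: class_id for class_id, name in mapping.items()}
    boxes, class_ids = [], []
    for obj in root.iter("object"):
        bbox = obj.find("bndbox")
        boxes.append([float(bbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")])
        class_ids.append(name_to_id[obj.find("name").text])
    return image_name, boxes, class_ids


name, boxes, ids = parse_annotation_string(SAMPLE_XML, class_mapping)
print(name, boxes, ids)
# frame_000123.jpg [[10.0, 20.0, 30.0, 45.0], [5.0, 5.0, 200.0, 150.0]] [0, 1]
```

Inverting the dictionary once gives the same IDs as the list-index lookup in `parse_annotation`, but in constant time per object.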
We use `tf.ragged.constant` to create ragged tensors from the `bbox` and `classes` lists. A ragged tensor is a type of tensor that can handle varying lengths of data along one or more dimensions. This is useful for variable-length data such as text or time series, and equally for object detection, where every image contains a different number of objects:

```python
classes = [
    [8, 8, 8, 8, 8],  # image 1 has 5 objects
    [1],              # image 2 has 1 object
    [7, 7],           # image 3 has 2 objects
    ...]
```

```python
bbox = [
    [[199.0, 19.0, 390.0, 401.0],
     [217.0, 15.0, 270.0, 157.0],
     [393.0, 18.0, 432.0, 162.0],
     [1.0, 15.0, 226.0, 276.0],
     [19.0, 95.0, 458.0, 443.0]],   # image 1 has 5 objects
    [[52.0, 117.0, 109.0, 177.0]],  # image 2 has 1 object
    [[88.0, 87.0, 235.0, 322.0],
     [113.0, 117.0, 218.0, 471.0]], # image 3 has 2 objects
    ...]
```

In this case, the `bbox` and `classes` lists have a different length for each image, depending on the number of objects it contains. To handle this variability, ragged tensors are used instead of regular tensors.

These ragged tensors are then used to create a `tf.data.Dataset` with the `from_tensor_slices` method, which slices the input tensors along their first dimension, yielding one element per image. Because the tensors are ragged, the dataset can handle a varying number of boxes per image and provides a flexible input pipeline for further processing.
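Under the hood, a `tf.RaggedTensor` stores its rows as one flat list of `values` plus a `row_splits` vector of offsets, where row `i` spans `values[row_splits[i]:row_splits[i + 1]]`. The pure-Python sketch below, with no TensorFlow required and hypothetical helper names, builds that encoding by hand for three images with 5, 1 and 2 objects to show why no padding is needed.

```python
def to_values_and_row_splits(rows):
    """Encode variable-length rows as (flat values, row offsets)."""
    values, row_splits = [], [0]
    for row in rows:
        values.extend(row)
        row_splits.append(len(values))  # end offset of this row
    return values, row_splits


def get_row(values, row_splits, i):
    """Recover row i from the flat encoding."""
    return values[row_splits[i]:row_splits[i + 1]]


classes = [[8, 8, 8, 8, 8], [1], [7, 7]]
values, row_splits = to_values_and_row_splits(classes)
print(values)                          # [8, 8, 8, 8, 8, 1, 7, 7]
print(row_splits)                      # [0, 5, 6, 8]
print(get_row(values, row_splits, 2))  # [7, 7]
```

For the same input, `tf.ragged.constant(classes).row_splits` holds the same offsets.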
```python
import tensorflow as tf

# Convert the nested Python lists into (ragged) tensors
bbox = tf.ragged.constant(bbox)
classes = tf.ragged.constant(classes)
image_paths = tf.ragged.constant(image_paths)  # flat list, one path per image
data = tf.data.Dataset.from_tensor_slices((image_paths, classes, bbox))

# Determine the number of validation samples
# (config holds the project settings and is defined in the full repository)
num_val = int(len(xml_files) * config.SPLIT_RATIO)

# Split the dataset into validation and training sets;
# take() and skip() keep the original file order
val_data = data.take(num_val)
train_data = data.skip(num_val)
```
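Because `take()` and `skip()` preserve the order of `xml_files`, the validation set consists of whichever files come first; if the files are sorted by video, it may cover only a few recordings. One option, sketched here as an assumption rather than what the original pipeline does, is to shuffle the file list with a fixed seed before splitting. `split_files` and the file names are hypothetical.

```python
import random


def split_files(files, split_ratio, seed=42):
    """Shuffle a copy of the file list reproducibly, then split it.

    Returns (train_files, val_files); the validation share is split_ratio.
    """
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)  # fixed seed -> same split every run
    num_val = int(len(shuffled) * split_ratio)
    return shuffled[num_val:], shuffled[:num_val]


# Hypothetical annotation files, sorted by video as they might be on disk
files = [f"video{v}_frame{f:04d}.xml" for v in range(3) for f in range(10)]
train_files, val_files = split_files(files, split_ratio=0.2)
print(len(train_files), len(val_files))  # 24 6
```

The two lists could then be parsed separately with `parse_annotation` to build the training and validation datasets.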
Bounding boxes in KerasCV have a predetermined format: they must be bundled into a dictionary that meets the following requirements:

```python
bounding_boxes = {
    # num_boxes may be a Ragged dimension
    'boxes': Tensor(shape=[batch, num_boxes, 4]),
    'classes': Tensor(shape=[batch, num_boxes])
}
```

The dictionary has two keys, `'boxes'` and `'classes'`, each of which maps to a TensorFlow `RaggedTensor` or `Tensor`. Ragged tensors are the TensorFlow equivalent of nested variable-length lists; they make it easy to store and process data with non-uniform shapes. The `'boxes'` tensor has the shape `[batch, num_boxes, 4]`, where `batch` is the number of images in the batch and `num_boxes` is the number of bounding boxes per image (for a dense tensor, the maximum number of boxes in any image of the batch). The 4 represents the four values that define a box; in the `xyxy` format these are xmin, ymin, xmax, ymax.

The `'classes'` tensor has the shape `[batch, num_boxes]`, where each element is the class label of the corresponding box in `'boxes'`. The `num_boxes` dimension may be ragged, which means that the number of boxes may vary across the images in the batch.

The final dictionary passed through the pipeline should look like this:

```python
{"images": images, "bounding_boxes": bounding_boxes}
```
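When a model needs rectangular tensors, the ragged `num_boxes` dimension has to be padded to a common length. KerasCV provides `keras_cv.bounding_box.to_dense` for this, filling missing entries with -1 by default. The NumPy sketch below illustrates the same idea on plain arrays; `pad_boxes` is a hypothetical helper, not part of KerasCV.

```python
import numpy as np


def pad_boxes(boxes_per_image, classes_per_image, fill=-1.0):
    """Pad variable-length lists to [batch, max_boxes, 4] and [batch, max_boxes]."""
    batch = len(boxes_per_image)
    max_boxes = max(len(b) for b in boxes_per_image)
    boxes = np.full((batch, max_boxes, 4), fill, dtype=np.float32)
    classes = np.full((batch, max_boxes), fill, dtype=np.float32)
    for i, (b, c) in enumerate(zip(boxes_per_image, classes_per_image)):
        if len(b):  # images without objects keep only padding
            boxes[i, : len(b)] = b
            classes[i, : len(c)] = c
    return {"boxes": boxes, "classes": classes}


bounding_boxes = pad_boxes(
    [[[52.0, 117.0, 109.0, 177.0]],
     [[88.0, 87.0, 235.0, 322.0], [113.0, 117.0, 218.0, 471.0]]],
    [[1], [7, 7]],
)
print(bounding_boxes["boxes"].shape)       # (2, 2, 4)
print(bounding_boxes["classes"].tolist())  # [[1.0, -1.0], [7.0, 7.0]]
```

Padded entries carry the class -1, so they can be told apart from real objects later on.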
One of the biggest challenges in building object detection pipelines is data augmentation. This process involves applying various transformations to input images to increase training data diversity and improve the model’s ability to generalize. However, in object detection tasks, augmentation becomes even more complex, as transformations must also account for bounding boxes and update them accordingly.
KerasCV simplifies this process by providing native support for bounding box augmentation. It offers a comprehensive set of data augmentation layers specifically designed to handle bounding boxes. These layers intelligently adjust bounding box coordinates as the image undergoes transformations, ensuring they remain accurate and properly aligned with the augmented images.
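To make the bounding box adjustment concrete: for a horizontal flip of an image of width W, a box in `xyxy` format (xmin, ymin, xmax, ymax) moves to (W - xmax, ymin, W - xmin, ymax). The sketch below applies that rule to a NumPy image and its boxes; it illustrates the geometry that layers like `RandomFlip` take care of, not KerasCV's actual implementation, and `hflip_with_boxes` is a hypothetical name.

```python
import numpy as np


def hflip_with_boxes(image, boxes):
    """Horizontally flip an HxWxC image and its xyxy pixel boxes."""
    width = image.shape[1]
    flipped = image[:, ::-1]
    boxes = np.asarray(boxes, dtype=np.float32)
    new_boxes = boxes.copy()
    new_boxes[:, 0] = width - boxes[:, 2]  # new xmin comes from old xmax
    new_boxes[:, 2] = width - boxes[:, 0]  # new xmax comes from old xmin
    return flipped, new_boxes


image = np.zeros((100, 200, 3), dtype=np.uint8)
image[10:20, 30:50] = 255  # bright patch exactly covering the box below
flipped, new_boxes = hflip_with_boxes(image, [[30, 10, 50, 20]])
print(new_boxes.tolist())  # [[150.0, 10.0, 170.0, 20.0]]
```

The bright patch ends up in columns 150 to 169 of the flipped image, which is exactly the region described by the new box.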
Finally, the training and validation splits are turned into batched input pipelines: the training data is augmented, while the validation data is only resized, so that the fit of the model can be verified on images it was not trained on.
```python
import keras_cv
import tensorflow as tf
from tensorflow import keras

# load_dataset (defined in the full repository) reads each image and returns
# the {"images": ..., "bounding_boxes": {...}} dictionary described above.
augmenter = keras.Sequential(
    layers=[
        keras_cv.layers.RandomFlip(
            mode="horizontal",
            bounding_box_format=config.BOUNDING_BOX_FORMAT,
        ),
        keras_cv.layers.RandomShear(
            x_factor=0.2,
            y_factor=0.2,
            bounding_box_format=config.BOUNDING_BOX_FORMAT,
        ),
        keras_cv.layers.JitteredResize(
            # target_size is given as (height, width)
            target_size=(
                config.IMAGE_DIMENSIONS_HEIGHT,
                config.IMAGE_DIMENSIONS_WIDTH,
            ),
            scale_factor=(0.75, 1.3),
            bounding_box_format=config.BOUNDING_BOX_FORMAT,
        ),
    ]
)

# Creating the training dataset: load, shuffle, batch (ragged), augment
train_ds = train_data.map(load_dataset, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.shuffle(config.BATCH_SIZE * 4)
train_ds = train_ds.ragged_batch(config.BATCH_SIZE, drop_remainder=True)
train_ds = train_ds.map(augmenter, num_parallel_calls=tf.data.AUTOTUNE)

# Creating the validation dataset: resizing only, no flips or shears
resizing = keras_cv.layers.JitteredResize(
    target_size=(
        config.IMAGE_DIMENSIONS_HEIGHT,
        config.IMAGE_DIMENSIONS_WIDTH,
    ),
    scale_factor=(0.75, 1.3),
    bounding_box_format=config.BOUNDING_BOX_FORMAT,
)

val_ds = val_data.map(load_dataset, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.shuffle(config.BATCH_SIZE * 4)
val_ds = val_ds.ragged_batch(config.BATCH_SIZE, drop_remainder=True)
val_ds = val_ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)


# Prepare the inputs for the model as (images, bounding_boxes) tuples
def dict_to_tuple(inputs):
    # Pad the ragged boxes to the largest number of boxes in the batch (fill value -1)
    return inputs["images"], keras_cv.bounding_box.to_dense(inputs["bounding_boxes"])


train_ds = train_ds.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

val_ds = val_ds.map(dict_to_tuple, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.prefetch(tf.data.AUTOTUNE)
```
Next, we build a YOLOv8 model using `keras_cv.models.YOLOV8Detector`. This model requires several key parameters:

- `backbone`: specifies the feature extractor used for learning representations.
- `num_classes`: defines the number of object classes to detect, based on the size of the `class_mapping` dictionary.
- `bounding_box_format`: informs the model of the bounding box format used in the dataset, ensuring compatibility.
- `fpn_depth`: determines the depth of the Feature Pyramid Network (FPN), which enhances multi-scale feature detection.

By configuring these parameters appropriately, we can adapt the YOLOv8 model to our specific object detection task.
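The `bounding_box_format` parameter matters because the same box can be written in several ways. KerasCV names formats such as `xyxy`, `xywh` (top-left corner plus width and height) and `center_xywh`; the sketch below converts one box between them by hand to show what the model would misread if the format were set incorrectly. The helper names are hypothetical.

```python
def xyxy_to_xywh(box):
    """(xmin, ymin, xmax, ymax) -> (xmin, ymin, width, height)."""
    xmin, ymin, xmax, ymax = box
    return (xmin, ymin, xmax - xmin, ymax - ymin)


def xyxy_to_center_xywh(box):
    """(xmin, ymin, xmax, ymax) -> (x_center, y_center, width, height)."""
    xmin, ymin, xmax, ymax = box
    return ((xmin + xmax) / 2, (ymin + ymax) / 2, xmax - xmin, ymax - ymin)


def center_xywh_to_xyxy(box):
    """(x_center, y_center, width, height) -> (xmin, ymin, xmax, ymax)."""
    xc, yc, w, h = box
    return (xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2)


box = (52.0, 117.0, 109.0, 177.0)
print(xyxy_to_xywh(box))                                # (52.0, 117.0, 57.0, 60.0)
print(xyxy_to_center_xywh(box))                         # (80.5, 147.0, 57.0, 60.0)
print(center_xywh_to_xyxy(xyxy_to_center_xywh(box)) == box)  # True
```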
```python
# backbone: the YOLOv8 feature extractor, created earlier (see the full repository)
yolo = keras_cv.models.YOLOV8Detector(
    backbone=backbone,
    num_classes=len(class_mapping),
    bounding_box_format=config.BOUNDING_BOX_FORMAT,
    fpn_depth=config.FPN_DEPTH,
)
```
The YOLOv8 model optimizes its performance using a combination of two key loss functions:

- Classification loss: measures the difference between predicted and actual class probabilities. Here `binary_crossentropy` is used: each class is scored independently, so for every detected object the model makes a separate yes/no decision per category (e.g. ant or wound).
- Box loss: the `box_loss` function evaluates the discrepancy between predicted and ground-truth bounding boxes. We use Complete IoU (CIoU), which extends the traditional Intersection over Union (IoU) by also penalizing the distance between the box centers (relative to the smallest box enclosing both) and differences in aspect ratio. This leads to more precise localization of objects.
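Both losses can be written out in a few lines. The sketch below follows the published CIoU definition, CIoU = IoU - rho^2 / c^2 - alpha * v, where rho is the distance between the box centers, c the diagonal of the smallest enclosing box, v = (4 / pi^2) * (arctan(w_gt / h_gt) - arctan(w / h))^2 the aspect-ratio term and alpha = v / ((1 - IoU) + v); the box loss is 1 - CIoU. It also shows per-class binary cross-entropy. This is plain Python for illustration, not KerasCV's implementation, which works on batched tensors and may differ in details such as epsilon handling.

```python
import math


def ciou_loss(pred, gt, eps=1e-7):
    """1 - CIoU for two boxes in xyxy format."""
    px1, py1, px2, py2 = pred
    gx1, gy1, gx2, gy2 = gt
    # Intersection over union
    iw = max(0.0, min(px2, gx2) - max(px1, gx1))
    ih = max(0.0, min(py2, gy2) - max(py1, gy1))
    inter = iw * ih
    pw, ph = px2 - px1, py2 - py1
    gw, gh = gx2 - gx1, gy2 - gy1
    iou = inter / (pw * ph + gw * gh - inter + eps)
    # Squared center distance, normalized by the squared enclosing-box diagonal
    rho2 = ((px1 + px2) / 2 - (gx1 + gx2) / 2) ** 2 + ((py1 + py2) / 2 - (gy1 + gy2) / 2) ** 2
    cw = max(px2, gx2) - min(px1, gx1)
    ch = max(py2, gy2) - min(py1, gy1)
    c2 = cw ** 2 + ch ** 2 + eps
    # Aspect-ratio consistency term and its weight
    v = (4 / math.pi ** 2) * (math.atan(gw / (gh + eps)) - math.atan(pw / (ph + eps))) ** 2
    alpha = v / ((1 - iou) + v + eps)
    return 1 - (iou - rho2 / c2 - alpha * v)


def binary_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean BCE over independent per-class probabilities."""
    total = 0.0
    for t, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1 - eps)  # avoid log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


print(round(ciou_loss((0, 0, 10, 10), (0, 0, 10, 10)), 4))  # 0.0
# A prediction shifted by 5 px is penalized more than one shifted by 1 px
print(ciou_loss((0, 0, 10, 10), (5, 0, 15, 10)) > ciou_loss((0, 0, 10, 10), (1, 0, 11, 10)))  # True
# One object, two classes (ant, wound), true label "ant"
print(round(binary_crossentropy([1, 0], [0.9, 0.2]), 4))  # 0.1643
```

Unlike plain 1 - IoU, the CIoU loss still produces a useful signal when boxes do not overlap at all, because the center-distance term keeps growing as the boxes move apart.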
By jointly minimizing these loss functions, YOLOv8 improves both classification accuracy and bounding box precision. The model is trained with the Adam optimizer and gradient clipping by global norm (`global_clipnorm`); for further settings, consult the Keras documentation on the Adam optimizer.
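`global_clipnorm` rescales all gradients together whenever their combined L2 norm exceeds the threshold, so the direction of the update is preserved while its size is capped. A NumPy sketch of that rule, mirroring the idea behind `tf.clip_by_global_norm`, with made-up gradient values:

```python
import numpy as np


def clip_by_global_norm(grads, clip_norm):
    """Scale a list of gradient arrays so their joint L2 norm is at most clip_norm."""
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)  # 1.0 if already small enough
    return [g * scale for g in grads], global_norm


grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]  # joint norm is 5
clipped, norm = clip_by_global_norm(grads, clip_norm=1.0)
print(norm)  # 5.0
# every gradient is scaled by 1/5, so the clipped set has norm 1
```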
```python
optimizer = tf.keras.optimizers.Adam(
    learning_rate=config.LEARNING_RATE,
    global_clipnorm=config.GLOBAL_CLIPNORM,  # clip gradients by their global norm
)

yolo.compile(
    classification_loss="binary_crossentropy",
    box_loss="ciou",
    optimizer=optimizer,
    jit_compile=False,  # disable XLA compilation
)
```
We use `BoxCOCOMetrics` to evaluate the model's performance by calculating COCO-style metrics such as mean average precision (mAP) at different IoU thresholds and object sizes, as well as recall. These metrics provide a comprehensive assessment of how well the model detects and classifies objects. Additionally, we implement a checkpointing mechanism that saves the model whenever the validation mAP improves, so that we retain the best-performing version of the model throughout training.
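mAP averages the average precision (AP) over classes and, in the COCO protocol, over IoU thresholds from 0.50 to 0.95 in steps of 0.05. The sketch below computes AP for a single class at a single IoU threshold, assuming detections have already been matched to ground truth (each marked as a true or false positive), and uses COCO-style 101-point interpolated precision. It is a simplified illustration with made-up scores; `BoxCOCOMetrics` additionally handles the matching, area ranges and detection limits.

```python
import numpy as np


def average_precision(scores, is_tp, num_gt):
    """101-point interpolated AP for one class from matched detections."""
    order = np.argsort(-np.asarray(scores, dtype=float))  # highest score first
    tp = np.asarray(is_tp, dtype=float)[order]
    fp = 1.0 - tp
    tp_cum, fp_cum = np.cumsum(tp), np.cumsum(fp)
    recall = tp_cum / num_gt
    precision = tp_cum / (tp_cum + fp_cum)
    # Precision envelope: best precision achievable at this recall or higher
    precision = np.maximum.accumulate(precision[::-1])[::-1]
    ap = 0.0
    for r in np.linspace(0.0, 1.0, 101):
        # First detection reaching recall >= r (0 if r is never reached)
        idx = np.searchsorted(recall, r, side="left")
        ap += precision[idx] if idx < len(precision) else 0.0
    return ap / 101


print(average_precision([0.9, 0.8, 0.7], [1, 1, 1], num_gt=3))  # 1.0
print(average_precision([0.9, 0.8, 0.7], [1, 0, 1], num_gt=2))
```

In the second call, the false positive ranked in the middle lowers the precision for the upper half of the recall range, so the AP drops below 1.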
```python
import keras_cv
from tensorflow import keras


class EvaluateCOCOMetricsCallback(keras.callbacks.Callback):
    """Computes COCO metrics on `data` after each epoch and saves the best model."""

    def __init__(self, data, save_path):
        super().__init__()
        self.data = data
        self.metrics = keras_cv.metrics.BoxCOCOMetrics(
            bounding_box_format=config.BOUNDING_BOX_FORMAT,
            # A very large value means the metrics are only computed on result(force=True)
            evaluate_freq=1e9,
        )

        self.save_path = save_path
        self.best_map = -1.0

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        self.metrics.reset_state()
        for batch in self.data:
            images, y_true = batch[0], batch[1]
            y_pred = self.model.predict(images, verbose=0)
            self.metrics.update_state(y_true, y_pred)

        metrics = self.metrics.result(force=True)
        logs.update(metrics)

        # Save the model whenever the mean average precision improves
        current_map = metrics["MaP"]
        if current_map >= self.best_map:
            self.best_map = current_map
            self.model.save(self.save_path)

        return logs


# Train the model on the training dataset and evaluate it on the validation dataset
# (shuffling is already handled in the tf.data pipeline)
history = yolo.fit(
    train_ds,
    epochs=config.NUM_EPOCHS,
    validation_data=val_ds,
    callbacks=[EvaluateCOCOMetricsCallback(val_ds, config.MODEL_PATH)],
    validation_freq=1,
)

# Print the model summary to get an overview of the model architecture and the number of parameters
yolo.summary()
```
The trained model can then be applied to the video frames to detect ants and wounded areas and locate their bounding boxes. By checking whether an ant's bounding box falls within a wound's bounding box, we can count the frames in which ants interact with wounds and compare them with the frames in which ants are present on unwounded branches.
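The counting step can be sketched as follows. It assumes the detections of each frame have already been collected into lists of `xyxy` boxes per class, and it treats an ant as being on a wound when the center of its box lies inside a wound box; this criterion, the data layout and all function names are assumptions for illustration, not necessarily the rule used in the study.

```python
def center(box):
    """Center point of an xyxy box."""
    xmin, ymin, xmax, ymax = box
    return (xmin + xmax) / 2, (ymin + ymax) / 2


def inside(point, box):
    """True if the point lies within the xyxy box (edges included)."""
    x, y = point
    xmin, ymin, xmax, ymax = box
    return xmin <= x <= xmax and ymin <= y <= ymax


def count_frames(frames):
    """Count frames with an ant on a wound vs. frames with ants only elsewhere.

    frames: list of dicts {"ants": [box, ...], "wounds": [box, ...]}
    """
    on_wound, elsewhere = 0, 0
    for frame in frames:
        if not frame["ants"]:
            continue  # no ant detected in this frame
        if any(inside(center(a), w) for a in frame["ants"] for w in frame["wounds"]):
            on_wound += 1
        else:
            elsewhere += 1
    return on_wound, elsewhere


# Three made-up frames: ant on the wound, ant away from the wound, no ant
frames = [
    {"ants": [(12, 12, 18, 18)], "wounds": [(0, 0, 50, 50)]},
    {"ants": [(80, 80, 90, 90)], "wounds": [(0, 0, 50, 50)]},
    {"ants": [], "wounds": [(0, 0, 50, 50)]},
]
print(count_frames(frames))  # (1, 1)
```

Using the box center keeps the rule simple; an overlap threshold based on IoU would be a stricter alternative.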
With the research findings expected to be published later this year, further studies are likely to follow, building on these results to deepen our understanding of the ants' behavior and potential symbiotic relationships.
For the full code, including documentation, please refer to the GitHub repository.
This post is part of research supporting a study on the behavior of ants in Costa Rica. Follow along for more machine learning posts like this one.