
Validation and grid search

Here, I will introduce the validation and grid search workflows. This will be different for each study system, and will depend on the format of the manual annotation you have.

This page is slightly different from the other pages in the documentation: I will give general guidance on generating evaluation metrics within Python, so that you can readily adapt this to your own system. As with the other pages, I will provide sample code for the Siberian jay system.

Running the sample code

If you would just like to try out the sample code, here is how you can run it for the Siberian jay system:

python Code/6_SampleValidation.py

Now, I will break down the validation procedure using the sample script for the Siberian jay dataset. The overall aim is to create two Python lists, one with the ground truth classes and one with the predicted classes. Then, we can use the sklearn library to calculate all kinds of evaluation metrics. How the ground truth and predicted classes are matched up will depend on the format of the manual annotation you have, which is why this step requires some data wrangling.
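For example, once these two lists exist, the metrics themselves only take a few lines of scikit-learn. Here is a minimal sketch with made-up labels, purely to illustrate the kind of calls used at the end of the sample script:

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Toy ground truth and predictions, just to illustrate the sklearn calls
ground_truth = ["Pecking", "Not Pecking", "Pecking", "Not Pecking"]
predicted = ["Pecking", "Not Pecking", "Not Pecking", "Not Pecking"]

print(accuracy_score(ground_truth, predicted))  # fraction of matching entries
print(confusion_matrix(ground_truth, predicted, labels=["Pecking", "Not Pecking"]))  # rows = truth, columns = prediction
print(classification_report(ground_truth, predicted))  # precision, recall, F1 per class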

Here, I define the two empty lists before looping through the videos to get predictions. Throughout the script, I then populate these lists with the ground truth and predicted classes.

def EventValidation(AllVideoNames, AllBORISFiles, BehavHyperParam):

    HyperParam = BehavHyperParam["Eat"]
    SortTracker = Sort(max_age=HyperParam["max_age"], min_hits=HyperParam["min_hits"], iou_threshold=HyperParam["iou_threshold"])
    GT_Data = []    # ground truth classes, populated below
    Pred_Data = []  # predicted classes, populated below

    for vidPath in tqdm(AllVideoNames):
        print(vidPath)

Then, for each video, I loaded the detection pickle file (see Visualization and inference for more details) and used the SORT tracker to get behavioural events.
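As a quick orientation before the code: judging from how it is indexed below, the detection pickle maps each frame number to that frame's bounding boxes, confidence scores, and class names. The values here are purely illustrative:

# DetectionDict maps frame number -> detections in that frame (illustrative values only):
# DetectionDict[100] == {"bbox": [[512.0, 230.5, 598.2, 310.7]],   # one box per detection
#                        "conf": [0.87],                           # YOLO confidence per box
#                        "Class": ["Eat"]}                         # predicted class per box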

        DetectionDict = pickle.load(open(vidPath, "rb"))
        file = [filePath for filePath in AllBORISFiles if os.path.basename(vidPath).split(".")[0] in filePath][0]

        ### Part 1: Process YOLO predictions first ###
        TrackingOutList = []
        for frame, frameDict in DetectionDict.items():
            # keep "Eat" detections above the confidence threshold, as a box plus its confidence score
            detectedClass = [frameDict["bbox"][x] + [frameDict["conf"][x]]
                             for x in range(len(frameDict["Class"]))
                             if frameDict["Class"][x] == "Eat"
                             and frameDict["conf"][x] > BehavHyperParam[frameDict["Class"][x]]["YOLO_Threshold"]]

            ## Update SORT trackers:
            if len(detectedClass) > 0:
                TrackingOut = SortTracker.update(np.array([detectedClass]).reshape(len(detectedClass), 5))
            else:
                TrackingOut = SortTracker.update(np.empty((0, 5)))
            if len(TrackingOut) == 0:
                TrackingOut = np.zeros((1, 5))  # placeholder row when nothing is tracked
            TrackingOutList.append(TrackingOut.tolist())

        ### Get start/end time of behaviours
        ### index 4 is the track id
        UnqEvents = list(set([bbox[4] for framelist in TrackingOutList for bbox in framelist]))
        BadTracks = []
        for track in UnqEvents:
            if track == 0:
                continue
            TrackIndexes = [frame for frame in range(len(TrackingOutList)) for bbox in TrackingOutList[frame] if bbox[4] == track]
            if len(TrackIndexes) < HyperParam["min_duration"]:
                BadTracks.append(track)  # discard tracks shorter than min_duration

        PredList = []
        for i in range(len(TrackingOutList)):
            tracks = [bbox[4] for bbox in TrackingOutList[i] if sum(bbox) > 0 and bbox[4] not in BadTracks]

            PredList.append(len(tracks))  # append how many tracks

This results in a list of bounding boxes and tracking IDs for each frame in the video. We will use this list to match up with the ground truth later.
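For intuition, each frame entry in TrackingOutList holds rows of the form [x1, y1, x2, y2, track_id] returned by SORT, with a single all-zero row as a placeholder when no track is active. The numbers below are illustrative only:

# TrackingOutList[100] -> [[512.0, 230.5, 598.2, 310.7, 3.0]]   # one "Eat" detection assigned to track 3
# TrackingOutList[101] -> [[0.0, 0.0, 0.0, 0.0, 0.0]]           # placeholder row, no active track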

Next, we will load the manual annotation from BORIS. The BORIS data is annotated in absolute time, so we need to multiply it by the frame rate (25 here) to get back the frame number (e.g. an event annotated at 3.4 s becomes frame 85). We then populate a list of the same length as the prediction list, and count the number of individuals "feeding" in each frame.

        ##### Part 2: Get ground truth ####
        df = pd.read_csv(file)

        ### round time col to frame number
        df["RoundTime"] = df["Time"].apply(lambda x: round(x*25))  ## get frame number instead

        RoundTimeCounts = df.groupby("RoundTime").apply(lambda x: len(x["Subject"].unique()))

        GTList = [0]*len(PredList)
        for i in RoundTimeCounts.index:
            if i >= len(GTList):
                break
            GTList[i] = RoundTimeCounts.loc[i]

Next, we will match up the ground truth and predicted classes, but within a certain time window. This will differ for different kinds of data; we refer to the original publication for a discussion of this. In typical cases, you may only have to check whether a predicted class correctly identified an event, but here, because of the slight mismatch between the manual annotation and the automated method, we need to summarize data points into time windows.

        #### Adjust window
        windowThresh = 2*25  # 2 second windows at 25 frames per second
        EqualChunks = [PredList[i:i + windowThresh] for i in range(0, len(PredList), windowThresh)]
        PredList = [max(x) for x in EqualChunks]  # maximum number detected within this chunk

        EqualChunks = [GTList[i:i + windowThresh] for i in range(0, len(GTList), windowThresh)]
        GTList = [max(x) for x in EqualChunks]  # maximum number detected within this chunk

Finally, we will loop through the time window lists and populate the global prediction and ground truth lists.

        ### if any pecking detected within window, then matches
        for i in range(len(GTList)):
            if GTList[i] == PredList[i] == 0:
                GT_Data.append("Not Pecking")
                Pred_Data.append("Not Pecking")
                continue
            elif GTList[i] > 0 and PredList[i] > 0:
                GT_Data.append("Pecking")
                Pred_Data.append("Pecking")
            elif GTList[i] > 0 and PredList[i] == 0:
                GT_Data.append("Pecking")
                Pred_Data.append("Not Pecking")
            elif GTList[i] == 0 and PredList[i] > 0:
                GT_Data.append("Not Pecking")
                Pred_Data.append("Pecking")
114

Now with the ground truth and predicted classes, we can use the sklearn library to calculate all kinds of evaluation metrics.

    ## Metrics
    GT_Data = [str(y) for y in GT_Data]
    Pred_Data = [str(y) for y in Pred_Data]

    SumCorrect = [1 for i in range(len(GT_Data)) if GT_Data[i] == Pred_Data[i]]
    LabelWithData = [1 for i in range(len(GT_Data)) if GT_Data[i] != "nan"]
    print(len(LabelWithData))

    Accuracy = sum(SumCorrect)/len(GT_Data)
    print("Overall Accuracy:%s" % Accuracy)

    labels = ["Pecking", "Not Pecking"]

    cm = confusion_matrix(GT_Data, Pred_Data, labels=labels, normalize="true")

    labels = ["Eating", "Not Eating"]  # relabel ticks for the plot
    colour = sns.cubehelix_palette(as_cmap=True)
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, xticklabels=labels,
                yticklabels=labels, cbar=False, cmap=colour,
                annot_kws={"fontsize": 16})

    plt.xlabel("Predicted", fontsize=20)
    plt.ylabel("Manual", fontsize=20)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.tight_layout()
    plt.show()

    OutReport = classification_report(GT_Data, Pred_Data, output_dict=False)
    print(OutReport)

    return OutReport

The code above will calculate the precision, recall, F1 score, and confusion matrix for the two lists of predicted and ground truth classes.
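As a usage sketch, the function could be called like this. The paths, file patterns, and hyperparameter values below are hypothetical placeholders, so adapt them to wherever your detection pickles and BORIS exports live:

import glob

# Hypothetical locations of the detection pickles and BORIS exports
AllVideoNames = sorted(glob.glob("Detections/*.pkl"))
AllBORISFiles = sorted(glob.glob("BORIS_Exports/*.csv"))

# Example hyperparameter values only; the grid search below is the place to tune these
BehavHyperParam = {"Eat": {"max_age": 10, "min_hits": 3, "iou_threshold": 0.3,
                           "min_duration": 5, "YOLO_Threshold": 0.5}}

report = EventValidation(AllVideoNames, AllBORISFiles, BehavHyperParam)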

Grid search

After figuring out the validation pipeline and obtaining metrics, you can then proceed to a grid search to optimize the hyperparameters of the SORT tracker. This step is not strictly necessary, but is a standardized way of making sure everything is optimized for your system.

You can run this in the terminal for the Siberian jay dataset:

python Code/7_SampleGridSearch.py

The core of the script is essentially the same as the validation script; the difference is that here, we run the validation many times to test out all combinations of hyperparameters.

    ### Define range of values to explore
    HyperParams = {"max_age": list(range(1, 26, 5)),
                   "min_hits": list(range(1, 6, 2)),
                   "iou_threshold": list(np.arange(0.1, 0.5, 0.1)),
                   "min_duration": list(range(1, 20, 2)),
                   "YOLO_Threshold": list(np.arange(0.1, 0.9, 0.1))}
    GridSearchDF = expand_grid(HyperParams)

Here is where we define the range of parameters we want to explore in the Code/7_SampleGridSearch.py script. The script will then loop through the combinations of these parameters and save the results as a CSV.
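If you are adapting this to your own script, expand_grid is simply a Cartesian product of the parameter lists. A minimal sketch of such a helper (the one in the repository may differ in its details) could look like this:

import itertools
import pandas as pd

def expand_grid(param_dict):
    """Return a DataFrame with one row per combination of parameter values."""
    keys = list(param_dict.keys())
    combinations = list(itertools.product(*param_dict.values()))
    return pd.DataFrame(combinations, columns=keys)

# GridSearchDF then has one row per hyperparameter combination,
# which can be fed into the validation function one row at a time.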

Note that the validation function can be essentially the same as the one used for validation above, except for a small change at the end to return the validation metrics instead of printing them.

    Precision, Recall, fbeta_score, support = sklearn.metrics.precision_recall_fscore_support(GT_Data, Pred_Data, labels=Labels, average=None, sample_weight=None, zero_division='warn')
    BehavIndex = Labels.index("Pecking")

    BehavOut = {"precision": Precision[BehavIndex], "recall": Recall[BehavIndex], "f1-score": fbeta_score[BehavIndex], "support": support[BehavIndex]}

    return iter, BehavOut

After running the script, the results will be saved in a CSV file, and you can then choose your best hyperparameters!
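For example, assuming the output CSV has one row per hyperparameter combination with an f1-score column (the exact file and column names may differ in your run), you could pick the best setting with pandas:

import pandas as pd

# Hypothetical file and column names; adjust to match your grid search output
results = pd.read_csv("GridSearchResults.csv")
best = results.sort_values("f1-score", ascending=False).iloc[0]
print(best)  # hyperparameter combination with the highest F1 score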

I hope this whole workflow was clear, and that is the whole YOLO-Behaviour pipeline! If you have any questions, feel free to contact me at hoi-hang.chan[at]uni-konstanz.de
