Validation and grid search
Here, I will introduce the validation and grid search workflows. These will differ for each study system and will depend on the format of your manual annotation.
This page is slightly different from the other pages in the documentation: I will give general guidance on generating evaluation metrics within Python, so that you can readily adapt this to your own system. As with the other pages, I will provide sample code for the Siberian Jay system.
Running the sample code
If you would just like to try out the sample code, here is how you can run it for the Siberian jay system:
python Code/6_SampleValidation.py
Now, I will break down the validation procedure using the sample script for the Siberian Jay dataset. The overall aim is to create two Python lists, one holding the ground truth classes and one holding the predicted classes; we can then use the sklearn library to calculate all kinds of evaluation metrics. How the ground truth and predicted classes are matched up depends on the format of the manual annotation you have, which is why this step requires some data wrangling.
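As a minimal sketch of that end goal (using made-up labels, not the jay data), two aligned lists are all sklearn needs:

from sklearn.metrics import classification_report, confusion_matrix

# Toy example: aligned ground-truth and predicted labels (hypothetical values)
GT_Data = ["Pecking", "Pecking", "Not Pecking", "Not Pecking", "Pecking"]
Pred_Data = ["Pecking", "Not Pecking", "Not Pecking", "Not Pecking", "Pecking"]

print(confusion_matrix(GT_Data, Pred_Data, labels=["Pecking", "Not Pecking"]))
print(classification_report(GT_Data, Pred_Data))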
Here, I define the two empty lists before looping through the videos to get predictions. Throughout the script, these lists are populated with the ground truth and predicted classes.
def EventValidation(AllVideoNames, AllBORISFiles, BehavHyperParam):

    HyperParam = BehavHyperParam["Eat"]
    SortTracker = Sort(max_age=HyperParam["max_age"], min_hits=HyperParam["min_hits"], iou_threshold=HyperParam["iou_threshold"])
    GT_Data = []
    Pred_Data = []

    for vidPath in tqdm(AllVideoNames):
        print(vidPath)
Then, for each video, I load the detection pickle file (see Visualization and inference for more details) and use the SORT tracker to get behavioural events.
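For orientation, here is a minimal, hypothetical example of the structure the detection pickle is assumed to have, inferred from how it is indexed in the code below:

# Hypothetical detection pickle contents: one dictionary per frame, with parallel
# lists of classes, bounding boxes and confidences (one entry per detection)
DetectionDict = {
    0: {"Class": ["Eat"],
        "bbox": [[120.0, 80.0, 180.0, 140.0]],   # [x1, y1, x2, y2] per detection
        "conf": [0.91]},
    1: {"Class": [], "bbox": [], "conf": []},     # a frame with no detections
}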
        DetectionDict = pickle.load(open(vidPath, "rb"))
        # Find the BORIS annotation file matching this video
        file = [filePath for filePath in AllBORISFiles if os.path.basename(vidPath).split(".")[0] in filePath][0]

        ### Part 1: Process YOLO predictions first ###
        TrackingOutList = []
        for frame, frameDict in DetectionDict.items():
            # Keep only "Eat" detections above the YOLO confidence threshold
            detectedClass = []
            for x in range(len(frameDict["Class"])):
                if frameDict["Class"][x] == "Eat" and frameDict["conf"][x] > BehavHyperParam[frameDict["Class"][x]]["YOLO_Threshold"]:
                    detectedClass.append(frameDict["bbox"][x] + [frameDict["conf"][x]])

            ## Update SORT tracker:
            if len(detectedClass) > 0:
                TrackingOut = SortTracker.update(np.array([detectedClass]).reshape(len(detectedClass), 5))
            else:
                TrackingOut = SortTracker.update(np.empty((0, 5)))
            if len(TrackingOut) == 0:
                TrackingOut = np.zeros((1, 5))
            TrackingOutList.append(TrackingOut.tolist())

        ### Get start/end time of behaviours
        ### index 4 is the track id
        UnqEvents = list(set([bbox[4] for framelist in TrackingOutList for bbox in framelist]))
        BadTracks = []
        for track in UnqEvents:
            if track == 0:
                continue
            TrackIndexes = [frame for frame in range(len(TrackingOutList)) for bbox in TrackingOutList[frame] if bbox[4] == track]
            if len(TrackIndexes) < HyperParam["min_duration"]:
                BadTracks.append(track)

        PredList = []
        for i in range(len(TrackingOutList)):
            tracks = [bbox[4] for bbox in TrackingOutList[i] if sum(bbox) > 0 and bbox[4] not in BadTracks]
            PredList.append(len(tracks))  # append how many tracks
This results in a list of bounding boxes and tracking IDs for each frame in the video. We will use this list to match up with the ground truth later.
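To make that intermediate format concrete, each entry of TrackingOutList holds one frame's SORT output as [x1, y1, x2, y2, track_id] rows, with a single row of zeros standing in for frames without any active track. A hypothetical two-frame example:

TrackingOutList = [
    [[120.0, 80.0, 180.0, 140.0, 1.0]],   # frame 0: one tracked "Eat" event, track id 1
    [[0.0, 0.0, 0.0, 0.0, 0.0]],          # frame 1: no active track (zero placeholder)
]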
Next, we will load the manual annotation from BORIS. The BORIS data is annotated in absolute time, so we need to multiply it by the frame rate (25 fps here) to convert it back to a frame number. We then populate a list of the same length as the predicted list with the number of individuals "feeding" in each frame.
        #### Part 2: Get ground truth ####
        df = pd.read_csv(file)

        ### round time column to frame number
        df["RoundTime"] = df["Time"].apply(lambda x: round(x*25))  ## get frame number instead

        RoundTimeCounts = df.groupby("RoundTime").apply(lambda x: len(x["Subject"].unique()))

        GTList = [0]*len(PredList)
        for i in RoundTimeCounts.index:
            if i >= len(GTList):
                break
            GTList[i] = RoundTimeCounts.loc[i]
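As a small worked example of the time-to-frame conversion (with made-up annotation times and subject names):

import pandas as pd

# Two subjects annotated as feeding at 3.48 s in a 25 fps video (hypothetical data)
df = pd.DataFrame({"Time": [3.48, 3.48], "Subject": ["BirdA", "BirdB"]})
df["RoundTime"] = df["Time"].apply(lambda x: round(x * 25))   # 3.48 s -> frame 87
print(df.groupby("RoundTime").apply(lambda x: len(x["Subject"].unique())))   # frame 87 -> 2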
Next, we will match up the ground truth and predicted classes within a certain time window. This will differ between kinds of data; we refer to the original publication for a discussion on this. In typical cases you may only have to check whether a predicted class correctly identified an event, but here, because of the slight mismatch between the manual annotation and the automated method, we need to summarize data points into time windows.
        #### Adjust window
        windowThresh = 2*25  # 2-second window at 25 fps
        EqualChunks = [PredList[i:i + windowThresh] for i in range(0, len(PredList), windowThresh)]
        PredList = [max(x) for x in EqualChunks]  # maximum number detected within this chunk

        EqualChunks = [GTList[i:i + windowThresh] for i in range(0, len(GTList), windowThresh)]
        GTList = [max(x) for x in EqualChunks]  # maximum number detected within this chunk
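To illustrate the chunking with made-up numbers (and a 4-frame window instead of the 50-frame window above):

PredList = [0, 1, 0, 0, 2, 0, 0, 0, 0]    # made-up per-frame counts
windowThresh = 4
EqualChunks = [PredList[i:i + windowThresh] for i in range(0, len(PredList), windowThresh)]
print([max(x) for x in EqualChunks])       # [1, 2, 0] -- one value per window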
Finally, we will loop through the time window lists and populate the global prediction and ground truth lists.
        ### if any pecking detected within window, then matches
        for i in range(len(GTList)):
            if GTList[i] == PredList[i] == 0:
                GT_Data.append("Not Pecking")
                Pred_Data.append("Not Pecking")
                continue
            elif GTList[i] > 0 and PredList[i] > 0:
                GT_Data.append("Pecking")
                Pred_Data.append("Pecking")
            elif GTList[i] > 0 and PredList[i] == 0:
                GT_Data.append("Pecking")
                Pred_Data.append("Not Pecking")
            elif GTList[i] == 0 and PredList[i] > 0:
                GT_Data.append("Not Pecking")
                Pred_Data.append("Pecking")
Now with the ground truth and predicted classes, we can use the sklearn library to calculate all kinds of evaluation metrics.
    ## Metrics
    GT_Data = [str(y) for y in GT_Data]
    Pred_Data = [str(y) for y in Pred_Data]

    SumCorrect = [1 for i in range(len(GT_Data)) if GT_Data[i] == Pred_Data[i]]
    LabelWithData = [1 for i in range(len(GT_Data)) if GT_Data[i] != "nan"]
    print(len(LabelWithData))

    Accuracy = sum(SumCorrect)/len(GT_Data)
    print("Overall Accuracy:%s" % Accuracy)

    labels = ["Pecking", "Not Pecking"]

    cm = confusion_matrix(GT_Data, Pred_Data, labels=labels, normalize="true")

    labels = ["Eating", "Not Eating"]  # display names for the plot axes
    colour = sns.cubehelix_palette(as_cmap=True)
    plt.figure(figsize=(6, 6))
    sns.heatmap(cm, annot=True, xticklabels=labels,
                yticklabels=labels, cbar=False, cmap=colour,
                annot_kws={"fontsize": 16})

    plt.xlabel("Predicted", fontsize=20)
    plt.ylabel("Manual", fontsize=20)
    plt.xticks(fontsize=16)
    plt.yticks(fontsize=16)
    plt.tight_layout()
    plt.show()

    OutReport = classification_report(GT_Data, Pred_Data, output_dict=False)
    print(OutReport)

    return OutReport
The code above will calculate the precision, recall, F1 score, and confusion matrix for the two lists of predicted and ground truth classes.
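If you would rather keep the numbers for further processing instead of just printing them, the same classification_report call can return a nested dictionary:

# Same inputs as above, but returning a dictionary keyed by class label
ReportDict = classification_report(GT_Data, Pred_Data, output_dict=True)
print(ReportDict["Pecking"]["f1-score"])   # per-class precision, recall and f1 are nested by label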
Grid search
After figuring out the validation pipeline and obtaining metrics, you can proceed to a grid search to optimize the hyperparameters of the SORT tracker. This step is not strictly necessary, but it is a standardized way of making sure everything is tuned for your system.
You can run this in the terminal for the Siberian jay dataset:
python Code/7_SampleGridSearch.py
The core of the script is essentially the same as the validation script; the difference is that here, we run the validation many times to test all combinations of hyperparameters.
### Define range of values to explore
HyperParams = {"max_age": list(range(1, 26, 5)),
               "min_hits": list(range(1, 6, 2)),
               "iou_threshold": list(np.arange(0.1, 0.5, 0.1)),
               "min_duration": list(range(1, 20, 2)),
               "YOLO_Threshold": list(np.arange(0.1, 0.9, 0.1))}
GridSearchDF = expand_grid(HyperParams)
Here is where we define the range of parameters we want to explore in the Code/7_SampleGridSearch.py script.
The script will then loop through the combinations of these parameters and save the results as a CSV file.
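The expand_grid helper is not shown here; a minimal sketch of how such a helper could build the combinations table (using itertools.product and pandas, not necessarily the exact implementation in the script) looks like this:

import itertools
import pandas as pd

def expand_grid(param_dict):
    # One row per combination of the listed hyperparameter values
    rows = list(itertools.product(*param_dict.values()))
    return pd.DataFrame(rows, columns=list(param_dict.keys()))

# e.g. 3 values of max_age x 2 values of min_hits -> 6 rows
print(expand_grid({"max_age": [1, 6, 11], "min_hits": [1, 3]}))

Each row of GridSearchDF then corresponds to one hyperparameter combination: the script runs the validation with that row's values, collects the returned metrics, and finally writes everything to a CSV file.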
Note that the validation function can be essentially the same as the one used above, except for a small change at the end to return the validation metrics instead of printing them.
    Precision, Recall, fbeta_score, support = sklearn.metrics.precision_recall_fscore_support(GT_Data, Pred_Data, labels=Labels, average=None, sample_weight=None, zero_division='warn')
    BehavIndex = Labels.index("Pecking")

    BehavOut = {"precision": Precision[BehavIndex], "recall": Recall[BehavIndex], "f1-score": fbeta_score[BehavIndex], "support": support[BehavIndex]}

    return iter, BehavOut
After running the script, the results will be saved in a CSV file, from which you can choose your best hyperparameters!
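For example, assuming one row per combination with the metric columns returned above (the file name and exact column names here are hypothetical and depend on how you save the output), picking the best combination is straightforward with pandas:

import pandas as pd

# Hypothetical file and column names -- adjust to match how your script saves its output
results = pd.read_csv("GridSearchResults.csv")
best = results.sort_values("f1-score", ascending=False).iloc[0]
print(best)   # the hyperparameter combination with the highest F1 score for "Pecking"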
I hope this whole workflow was clear, and that concludes the YOLO-Behaviour pipeline! If you have any questions, feel free to contact me at hoi-hang.chan[at]uni-konstanz.de