Source code for yolov3.src.NetworkTester

import argparse
import time
from yolov3.src.models import *
from yolov3.src.datasets.datasets import *
from yolov3.utils.utils import *
from yolov3.src.InputFile import *

class NetworkTester():
    """
    Class for handling testing and assessing the performance of a trained YOLOv3 model.

    | **Inputs:**
    | *model:* YOLOv3 network object (its weights are loaded from the PyTorch .pt file specified in *inputs*)
    | *dataloader:* dataloader object (usually an instantiation of the ImageFolder class)
    | *inputs:* input file with various user-specified options
    """

    def __init__(self, model, dataloader, inputs):
        self.__inputs = inputs
        self.__dataloader = dataloader
        self.model = model
        self.setupCuda()
        self.loadSavedModels()
        self.loadClasses()
    def setupCuda(self):
        """ Basic method to set up GPU/CUDA support, if available """
        cuda = torch.cuda.is_available()
        self.__device = torch.device('cuda:0' if cuda else 'cpu')

        # Fix random seeds for reproducibility
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        if cuda:
            torch.cuda.manual_seed(0)
            torch.cuda.manual_seed_all(0)
            torch.backends.cudnn.benchmark = True
            if torch.cuda.device_count() > 1:
                print('Using ', torch.cuda.device_count(), ' GPUs')
    def loadSavedModels(self):
        """ Method to load a saved YOLOv3 model from a PyTorch (.pt) file. """
        checkpoint = torch.load(self.__inputs.networksavefile, map_location='cpu')
        self.model.load_state_dict(checkpoint['model'])
    def loadClasses(self):
        """
        Method to load class names from the path specified in the user-input file.
        The file is assumed to be a CSV list of (class_name, class_label) pairs.
        """
        class_names, class_labels = load_classes(self.__inputs.class_path)
        self.__classes = class_names
        self.__class_labels = class_labels
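    # Illustrative contents of the class file (an assumed example, not taken
    # from the repository): one "class_name,class_label" pair per line, e.g.
    #
    #   building,1
    #   small-vehicle,2
    #   boat,3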
    def detect(self):
        """Method to compute object detections over the testing dataset"""
        print('********************* NETWORK TESTING *********************')
        imgs = []            # Stores image paths
        img_detections = []  # Stores detections for each image index
        prev_time = time.time()
        detections = None
        for batch_i, (img_paths, img) in enumerate(self.__dataloader):
            print('\n', batch_i, img.shape, end=' ')
            img_ud = np.ascontiguousarray(np.flip(img, axis=1))  # vertically flipped copy (not used below)
            img_lr = np.ascontiguousarray(np.flip(img, axis=2))  # horizontally flipped copy (not used below)
            preds = []

            # Tile the full image into imgsize x imgsize chips and run the network on each chip
            length = self.__inputs.imgsize
            ni = int(math.ceil(img.shape[1] / length))  # up-down
            nj = int(math.ceil(img.shape[2] / length))  # left-right
            for i in range(ni):  # for i in range(ni - 1):
                print('row %g/%g: ' % (i, ni), end='')
                for j in range(nj):  # for j in range(nj if i==0 else nj - 1):
                    print('%g ' % j, end='', flush=True)

                    # forward scan
                    y2 = min((i + 1) * length, img.shape[1])
                    y1 = y2 - length
                    x2 = min((j + 1) * length, img.shape[2])
                    x1 = x2 - length

                    # Get detections
                    with torch.no_grad():
                        # Normal orientation
                        chip = torch.from_numpy(img[:, y1:y2, x1:x2]).unsqueeze(0)
                        pred = self.model(chip)
                        pred = pred[pred[:, :, 4] > self.__inputs.conf_thres]

                    if len(pred) > 0:
                        # Shift chip-local box coordinates back to full-image coordinates
                        pred[:, 0] += x1
                        pred[:, 1] += y1
                        preds.append(pred.unsqueeze(0))

            if len(preds) > 0:
                detections = non_max_suppression(torch.cat(preds, 1), self.__inputs.conf_thres,
                                                 self.__inputs.nms_thres, opt=self.__inputs, img=img)
                img_detections.extend(detections)
                imgs.extend(img_paths)

            print('Batch %d... (Done %.3fs)' % (batch_i, time.time() - prev_time))
            prev_time = time.time()

        self.__img_detections = img_detections
        self.__imgs = imgs
    def plotDetection(self):
        """Method to plot and save all detected objects in the testing dataset"""
        # Bounding-box colors
        color_list = [[random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)]
                      for _ in range(len(self.__classes))]

        if len(self.__img_detections) == 0:
            return

        # Iterate through images and save plot of detections
        for img_i, (path, detections) in enumerate(zip(self.__imgs, self.__img_detections)):
            print("image %g: '%s'" % (img_i, path))
            if self.__inputs.plot_flag:
                img = cv2.imread(path)

            # Draw bounding boxes and labels of detections
            if detections is not None:
                unique_classes = detections[:, -1].cpu().unique()
                bbox_colors = random.sample(color_list, len(unique_classes))

                # Write results to .txt file (one line per detection)
                results_path = os.path.join(self.__inputs.outdir, path.split('/')[-1])
                if os.path.isfile(results_path + '.txt'):
                    os.remove(results_path + '.txt')
                results_img_path = os.path.join(self.__inputs.outdir, path.split('/')[-1])

                with open(results_path.replace('.bmp', '.tif') + '.txt', 'a') as file:
                    for i in unique_classes:
                        n = (detections[:, -1].cpu() == i).sum()
                        print('%g %ss' % (n, self.__classes[int(i)]))

                    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
                        x1, y1, x2, y2 = max(x1, 0), max(y1, 0), max(x2, 0), max(y2, 0)

                        # Write to file: x1 y1 x2 y2 class_label confidence
                        class_label = self.__class_labels[int(cls_pred)]
                        file.write('%g %g %g %g %g %g \n' % (x1, y1, x2, y2, class_label, cls_conf * conf))

                        if self.__inputs.plot_flag:
                            # Add the bbox to the plot
                            label = '%s %.2f' % (self.__classes[int(cls_pred)], cls_conf) if cls_conf > self.__inputs.cls_thres else None
                            color = bbox_colors[int(np.where(unique_classes == int(cls_pred))[0])]
                            plot_one_box([x1, y1, x2, y2], img, label=label, color=color, line_thickness=1)

                if self.__inputs.plot_flag:
                    # Save generated image with detections
                    cv2.imwrite(results_img_path.replace('.bmp', '.jpg').replace('.tif', '.jpg'), img)
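A minimal usage sketch (not part of the module above): it assumes `InputFile` can be constructed from a path to the user's options file and exposes the attributes referenced in this class (`imgsize`, `conf_thres`, `nms_thres`, `cls_thres`, `class_path`, `networksavefile`, `outdir`, `plot_flag`); the model and dataloader construction are repo-specific and left as placeholders, and the options-file path is illustrative only.

# Hypothetical driver (a sketch under the assumptions stated above, not repository code)
inputs = InputFile('test_options.dat')   # path is illustrative only
model = ...                              # YOLOv3 network whose architecture matches inputs.networksavefile
dataloader = ...                         # iterable yielding (img_paths, img) batches, as detect() expects
tester = NetworkTester(model, dataloader, inputs)
tester.detect()                          # sliding-window detection over the test images
tester.plotDetection()                   # per-image .txt results and annotated images written to inputs.outdir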