AI Research/Object Detection

[Object Detection] Yolov5 dataloaders.py 코드 분석

wawawaaw 2023. 4. 8. 01:35
반응형

⭐ Yolov5에서 데이터를 load하는 코드 dataloaders.py

진행 중인 연구에서 이미지 전처리를 해야해서 분석한 코드 내용을 정리하였다. 

아래 yolov5 깃허브 코드 분석 내용이다. 

https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py

 

GitHub - ultralytics/yolov5: YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite

YOLOv5 🚀 in PyTorch > ONNX > CoreML > TFLite. Contribute to ultralytics/yolov5 development by creating an account on GitHub.

github.com

Multiple stream의 데이터를 load하는 LoadStreams 클래스,

폴더에 있는 이미지를 load하는 LoadImages 클래스 웹 캠으로부터 받은 이미지를 load하는 LoadWebcam 클래스

training을 위한 이미지 데이터와 라벨을 load하는 LoadImageAndLables 클래스

와 부수적인 클래스들로 구성되어 있는 파일이다.

 

LoadImage 클래스

  • 입력받은 source 경로나 파일로부터 이미지를 load하는 클래스이다.

1. Init 함수

class LoadImages:
    # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
    def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
        files = []
        for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
            p = str(Path(p).resolve())
            if '*' in p:
                files.extend(sorted(glob.glob(p, recursive=True)))  # glob
            elif os.path.isdir(p):
                files.extend(sorted(glob.glob(os.path.join(p, '*.*'))))  # dir
            elif os.path.isfile(p):
                files.append(p)  # files
            else:
                raise FileNotFoundError(f'{p} does not exist')

        images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
        videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
        ni, nv = len(images), len(videos)

        self.img_size = img_size
        self.stride = stride
        self.files = images + videos
        self.nf = ni + nv  # number of files
        self.video_flag = [False] * ni + [True] * nv
        self.mode = 'image'
        self.auto = auto
        self.transforms = transforms  # optional
        self.vid_stride = vid_stride  # video frame-rate stride
        if any(videos):
            self._new_video(videos[0])  # new video
        else:
            self.cap = None
        assert self.nf > 0, f'No images or videos found in {p}. ' \
                            f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
  • init 함수
    • files list : 탐색할 이미지 list
    • for문 path(입력 받은 경로 source)에 있는 값이 list나 tuple이 아니면 []로 list로 만들어준다
      • path 전체에 대해서 경로 아래에 있는 file들을 files list에 추가
    • images, videos 변수 초기화
    • ni,nv 개수 변수 초기화
    • 기타 변수 초기화
    • 비디오가 있다면 비디오 캡쳐 함수
    • 아니라면 self.cap 초기화
    • 파일의 개수 nf가 존재하는 경우에만 진행.

2. Next , iter

def __next__(self):
        if self.count == self.nf:
            raise StopIteration
        path = self.files[self.count]

        if self.video_flag[self.count]:
            # Read video
            self.mode = 'video'
            ret_val, im0 = self.cap.read()
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.vid_stride * (self.frame + 1))  # read at vid_stride
            while not ret_val:
                self.count += 1
                self.cap.release()
                if self.count == self.nf:  # last video
                    raise StopIteration
                path = self.files[self.count]
                self._new_video(path)
                ret_val, im0 = self.cap.read()

            self.frame += 1
            # im0 = self._cv2_rotate(im0)  # for use if cv2 autorotation is False
            s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '

        else:
            # Read image
            self.count += 1
            im0 = cv2.imread(path)  # BGR
            assert im0 is not None, f'Image Not Found {path}'
            s = f'image {self.count}/{self.nf} {path}: '

        if self.transforms:
            im = self.transforms(im0)  # transforms
        else:
            im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0]  # padded resize
            im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
            im = np.ascontiguousarray(im)  # contiguous

        return path, im, im0, self.cap, s

def __iter__(self):
        self.count = 0
        return self
  • count 개수가 차면 stopIteration
  • path = self.files [self.count]
  • 비디오 인 경우 비디오로 읽어오기
  • 아닌 경우 image로 읽어 들이기
    • 이미지 개수 숫자 count ++
    • imread로 읽어들이기
  • image transform 과정 거치기

 

LoadStreams 클래스

1. Init함수

class LoadStreams:
    # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4'  # RTSP, RTMP, HTTP streams`
    def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
        self.mode = 'stream'
        self.img_size = img_size
        self.stride = stride

        if os.path.isfile(sources):
            with open(sources) as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
        else:
            sources = [sources]

        n = len(sources)
        self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        self.auto = auto
        for i, s in enumerate(sources):  # index, source
            # Start thread to read frames from video stream
            st = f'{i + 1}/{n}: {s}... '
            if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'):  # if source is YouTube video
                check_requirements(('pafy', 'youtube_dl==2020.12.2'))
                import pafy
                s = pafy.new(s).getbest(preftype="mp4").url  # YouTube URL
            s = eval(s) if s.isnumeric() else s  # i.e. s = '0' local webcam
            if s == 0:
                assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
                assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
            cap = cv2.VideoCapture(s)
            assert cap.isOpened(), f'{st}Failed to open {s}'
            w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)  # warning: may return 0 or nan
            self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf')  # infinite stream fallback
            self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30  # 30 FPS fallback

            _, self.imgs[i] = cap.read()  # guarantee first frame
            self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
            LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
            self.threads[i].start()
        LOGGER.info('')  # newline

        # check for common shapes
        s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
        self.rect = np.unique(s, axis=0).shape[0] == 1  # rect inference if all shapes equal
        if not self.rect:
            LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
  • init 함수
    • files list : 탐색할 이미지 list
    • for문 path(입력 받은 경로 source)에 있는 값이 list나 tuple이 아니면 []로 list로 만들어준다
      • path 전체에 대해서 경로 아래에 있는 file들을 files list에 추가
    • images, videos 변수 초기화
    • ni,nv 개수 변수 초기화
    • 기타 변수 초기화
    • 비디오가 있다면 비디오 캡쳐 함수
    • 아니라면 self.cap 초기화
    • 파일의 개수 nf가 존재하는 경우에만 진행.

2. update

  • load streams의 update함수. 이 함수가 thread가 됨.
def update(self, i, cap, stream):
        # Read stream `i` frames in daemon thread
        n, f = 0, self.frames[i]  # frame number, frame array
        while cap.isOpened() and n < f:
            n += 1
            cap.grab()  # .read() = .grab() followed by .retrieve()
            if n % self.vid_stride == 0:
                success, im = cap.retrieve()
                if success:
                    self.imgs[i] = im
                else:
                    LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
                    self.imgs[i] = np.zeros_like(self.imgs[i])
                    cap.open(stream)  # re-open stream if signal was lost
            time.sleep(0.0)  # wait time
  • frame의 수 n과 frame array f 초기화
  • video가 열려있는 동안에, frame의 수보다 개수가 작은 동안에
    • cap.grab 이미지를 가지고 와서 self.imgs에 추가

3. iter, next, len

    def __iter__(self):
		        self.count = -1
		        return self
	
    def __next__(self):
        self.count += 1
        if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'):  # q to quit
            cv2.destroyAllWindows()
            raise StopIteration

        im0 = self.imgs.copy()
        if self.transforms:
            im = np.stack([self.transforms(x) for x in im0])  # transforms
        else:
            im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0])  # resize
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW
            im = np.ascontiguousarray(im)  # contiguous

        return self.sources, im, im0, None, ''

    def __len__(self):
        return len(self.sources)  # 1E12 frames = 32 streams at 30 FPS for 30 years
  • iter
  • next
    • thread가 더 이상 존재하지 않으면 raise stopIteration.
    • 그게 아니라면..
      • 가져온 img transform 과정 거치기
  • len
    • 전체 source의 length를 return 하는 함수

LoadMultiImages Class

1. Init 함수

  • list of files
  • thread 생성하고 start

2. update

3. iter, next , len

 

#load images from multiple dir with theading
class LoadMultiImages:
    def __init__(self, sources='dir.txt',img_size = 640, stride =32, auto =True ):
        self.mode = "dir"
        self.img_size = img_size
        self.stride = stride
        
        #open txt file
        if os.path.isfile(sources):
            with open(sources) as f:
                sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip)]
        else:
            sources = [sources]

        n = len(sources)

        self.images, self.fps, self.frames, self.threads = [None]*n , [0]*n , [0]*n , [None]*n 
        self.sources = [clean_str(x) for x in sources]  # clean source names for later
        self.auto = auto

        self.ni = 0 


        for i , s in enumerate(sources):

            files = []
            for p in sorted(s) if isinstance(s,(list, tuple)) else [s]:
                p = str(Path(p).resolve())
                if '*' in p :
                    files.extend(sorted(glob.glob(p,recursive=True)))
                elif os.path.isdir(p):
                    files.extend(sorted(glob.glob(os.path.join(p,'*.*'))))
                elif os.path.isfile(p):
                    files.append(p)
                else:
                    raise FileNotFoundError(f'{p} does not exist')

                cap = None 
                self.images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
                self.ni = len(self.images)


                self.threads[i] = Thread(target=self.update , args=[i, cap, files],daemon= True)
                self.threads[i].start() 

    def update(self, i, cap, files):
        n, f = 0 , self.frames #frame number, frame array
                
        while n<f:
            n += 1

 

본 포스팅에서 yolov5의 dataloader 코드 부분을 간단하게 나마 정리해보았다. 

 

반응형