반응형
⭐ Yolov5에서 데이터를 load하는 코드 dataloaders.py
진행 중인 연구에서 이미지 전처리를 해야해서 분석한 코드 내용을 정리하였다.
아래 yolov5 깃허브 코드 분석 내용이다.
https://github.com/ultralytics/yolov5/blob/master/utils/dataloaders.py
Multiple stream의 데이터를 load하는 LoadStreams 클래스,
폴더에 있는 이미지를 load하는 LoadImages 클래스 웹 캠으로부터 받은 이미지를 load하는 LoadWebcam 클래스
training을 위한 이미지 데이터와 라벨을 load하는 LoadImageAndLables 클래스
와 부수적인 클래스들로 구성되어 있는 파일이다.
LoadImage 클래스
- 입력받은 source 경로나 파일로부터 이미지를 load하는 클래스이다.
1. Init 함수
class LoadImages:
# YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
files = []
for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
p = str(Path(p).resolve())
if '*' in p:
files.extend(sorted(glob.glob(p, recursive=True))) # glob
elif os.path.isdir(p):
files.extend(sorted(glob.glob(os.path.join(p, '*.*')))) # dir
elif os.path.isfile(p):
files.append(p) # files
else:
raise FileNotFoundError(f'{p} does not exist')
images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
videos = [x for x in files if x.split('.')[-1].lower() in VID_FORMATS]
ni, nv = len(images), len(videos)
self.img_size = img_size
self.stride = stride
self.files = images + videos
self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv
self.mode = 'image'
self.auto = auto
self.transforms = transforms # optional
self.vid_stride = vid_stride # video frame-rate stride
if any(videos):
self._new_video(videos[0]) # new video
else:
self.cap = None
assert self.nf > 0, f'No images or videos found in {p}. ' \
f'Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}'
- init 함수
- files list : 탐색할 이미지 list
- for문 path(입력 받은 경로 source)에 있는 값이 list나 tuple이 아니면 []로 list로 만들어준다
- path 전체에 대해서 경로 아래에 있는 file들을 files list에 추가
- images, videos 변수 초기화
- ni,nv 개수 변수 초기화
- 기타 변수 초기화
- 비디오가 있다면 비디오 캡쳐 함수
- 아니라면 self.cap 초기화
- 파일의 개수 nf가 존재하는 경우에만 진행.
2. Next , iter
def __next__(self):
if self.count == self.nf:
raise StopIteration
path = self.files[self.count]
if self.video_flag[self.count]:
# Read video
self.mode = 'video'
ret_val, im0 = self.cap.read()
self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.vid_stride * (self.frame + 1)) # read at vid_stride
while not ret_val:
self.count += 1
self.cap.release()
if self.count == self.nf: # last video
raise StopIteration
path = self.files[self.count]
self._new_video(path)
ret_val, im0 = self.cap.read()
self.frame += 1
# im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
else:
# Read image
self.count += 1
im0 = cv2.imread(path) # BGR
assert im0 is not None, f'Image Not Found {path}'
s = f'image {self.count}/{self.nf} {path}: '
if self.transforms:
im = self.transforms(im0) # transforms
else:
im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im) # contiguous
return path, im, im0, self.cap, s
def __iter__(self):
self.count = 0
return self
- count 개수가 차면 stopIteration
- path = self.files [self.count]
- 비디오 인 경우 비디오로 읽어오기
- 아닌 경우 image로 읽어 들이기
- 이미지 개수 숫자 count ++
- imread로 읽어들이기
- image transform 과정 거치기
LoadStreams 클래스
1. Init함수
class LoadStreams:
# YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
self.mode = 'stream'
self.img_size = img_size
self.stride = stride
if os.path.isfile(sources):
with open(sources) as f:
sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
else:
sources = [sources]
n = len(sources)
self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
self.sources = [clean_str(x) for x in sources] # clean source names for later
self.auto = auto
for i, s in enumerate(sources): # index, source
# Start thread to read frames from video stream
st = f'{i + 1}/{n}: {s}... '
if urlparse(s).hostname in ('www.youtube.com', 'youtube.com', 'youtu.be'): # if source is YouTube video
check_requirements(('pafy', 'youtube_dl==2020.12.2'))
import pafy
s = pafy.new(s).getbest(preftype="mp4").url # YouTube URL
s = eval(s) if s.isnumeric() else s # i.e. s = '0' local webcam
if s == 0:
assert not is_colab(), '--source 0 webcam unsupported on Colab. Rerun command in a local environment.'
assert not is_kaggle(), '--source 0 webcam unsupported on Kaggle. Rerun command in a local environment.'
cap = cv2.VideoCapture(s)
assert cap.isOpened(), f'{st}Failed to open {s}'
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS) # warning: may return 0 or nan
self.frames[i] = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0) or float('inf') # infinite stream fallback
self.fps[i] = max((fps if math.isfinite(fps) else 0) % 100, 0) or 30 # 30 FPS fallback
_, self.imgs[i] = cap.read() # guarantee first frame
self.threads[i] = Thread(target=self.update, args=([i, cap, s]), daemon=True)
LOGGER.info(f"{st} Success ({self.frames[i]} frames {w}x{h} at {self.fps[i]:.2f} FPS)")
self.threads[i].start()
LOGGER.info('') # newline
# check for common shapes
s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
if not self.rect:
LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
- init 함수
- files list : 탐색할 이미지 list
- for문 path(입력 받은 경로 source)에 있는 값이 list나 tuple이 아니면 []로 list로 만들어준다
- path 전체에 대해서 경로 아래에 있는 file들을 files list에 추가
- images, videos 변수 초기화
- ni,nv 개수 변수 초기화
- 기타 변수 초기화
- 비디오가 있다면 비디오 캡쳐 함수
- 아니라면 self.cap 초기화
- 파일의 개수 nf가 존재하는 경우에만 진행.
2. update
- load streams의 update함수. 이 함수가 thread가 됨.
def update(self, i, cap, stream):
# Read stream `i` frames in daemon thread
n, f = 0, self.frames[i] # frame number, frame array
while cap.isOpened() and n < f:
n += 1
cap.grab() # .read() = .grab() followed by .retrieve()
if n % self.vid_stride == 0:
success, im = cap.retrieve()
if success:
self.imgs[i] = im
else:
LOGGER.warning('WARNING: Video stream unresponsive, please check your IP camera connection.')
self.imgs[i] = np.zeros_like(self.imgs[i])
cap.open(stream) # re-open stream if signal was lost
time.sleep(0.0) # wait time
- frame의 수 n과 frame array f 초기화
- video가 열려있는 동안에, frame의 수보다 개수가 작은 동안에
- cap.grab 이미지를 가지고 와서 self.imgs에 추가
3. iter, next, len
def __iter__(self):
self.count = -1
return self
def __next__(self):
self.count += 1
if not all(x.is_alive() for x in self.threads) or cv2.waitKey(1) == ord('q'): # q to quit
cv2.destroyAllWindows()
raise StopIteration
im0 = self.imgs.copy()
if self.transforms:
im = np.stack([self.transforms(x) for x in im0]) # transforms
else:
im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
im = np.ascontiguousarray(im) # contiguous
return self.sources, im, im0, None, ''
def __len__(self):
return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
- iter
- next
- thread가 더 이상 존재하지 않으면 raise stopIteration.
- 그게 아니라면..
- 가져온 img transform 과정 거치기
- len
- 전체 source의 length를 return 하는 함수
LoadMultiImages Class
1. Init 함수
- list of files
- thread 생성하고 start
2. update
3. iter, next , len
#load images from multiple dir with theading
class LoadMultiImages:
def __init__(self, sources='dir.txt',img_size = 640, stride =32, auto =True ):
self.mode = "dir"
self.img_size = img_size
self.stride = stride
#open txt file
if os.path.isfile(sources):
with open(sources) as f:
sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip)]
else:
sources = [sources]
n = len(sources)
self.images, self.fps, self.frames, self.threads = [None]*n , [0]*n , [0]*n , [None]*n
self.sources = [clean_str(x) for x in sources] # clean source names for later
self.auto = auto
self.ni = 0
for i , s in enumerate(sources):
files = []
for p in sorted(s) if isinstance(s,(list, tuple)) else [s]:
p = str(Path(p).resolve())
if '*' in p :
files.extend(sorted(glob.glob(p,recursive=True)))
elif os.path.isdir(p):
files.extend(sorted(glob.glob(os.path.join(p,'*.*'))))
elif os.path.isfile(p):
files.append(p)
else:
raise FileNotFoundError(f'{p} does not exist')
cap = None
self.images = [x for x in files if x.split('.')[-1].lower() in IMG_FORMATS]
self.ni = len(self.images)
self.threads[i] = Thread(target=self.update , args=[i, cap, files],daemon= True)
self.threads[i].start()
def update(self, i, cap, files):
n, f = 0 , self.frames #frame number, frame array
while n<f:
n += 1
본 포스팅에서 yolov5의 dataloader 코드 부분을 간단하게 나마 정리해보았다.
반응형
'AI Research > Object Detection' 카테고리의 다른 글
[Object Detection] MMdetection으로 RPN 실행하기 (0) | 2023.04.08 |
---|---|
[Object Detection] Python image padding 하는 방법 (0) | 2023.04.08 |
[Object Detection] 오픈소스로 detection 결과 mAP 측정하기 (0) | 2023.01.13 |
[Object Detection] mmdetection의 input image size 변경하는 방법 (Faster r-cnn) (0) | 2022.12.28 |
[Object Detection] YOLOv5로 multi stream, multi camera object detection하는 법/ 동시에 여러 video detection 진행하기 (3) | 2022.08.08 |