[강화학습] roboflow 실습
2025. 3. 23. 18:42ㆍComputer Engineering/강화학습
1. Roboflow Universe 에서 데이터셋을 가져온다.
나 같은 경우 아래와 같은 lane 데이터셋을 다운로드 받았다.
https://universe.roboflow.com/duyen/lane-unzou/fork
lane Instance Segmentation Dataset by duyen
352 open source lane images. lane dataset by duyen
universe.roboflow.com
download format에서 본인이 필요로 하는 코드를 다운 받으면 된다. 나는 YOLOv8을 선택했다.
그렇게 다운로드 받은 코드를 실행하게 되면
data.yaml 파일이 들어 있는 폴더가 만들어지게 되는데, 이 폴더를 꼭 'datasets' 폴더 경로에 넣어야 한다.
2. YOLOv8로 학습하기
터미널에 아래 코드를 입력한다.
yolo detect train model=yolov8n.pt data=your-dataset/data.yaml epochs=50 imgsz=640
위와 같이 학습이 진행되기 때문에 시간이 생각보다 오래 걸릴 수 있다.
3. 학습이 끝난 후에는 자동으로 해당하는 경로에 모델이 저장된다.
→ 커스텀 객체 탐지 모델
4. 저장된 모델 경로를 강화학습 환경에 적용한다.
YOLO가 실시간으로 lane, stop, intersection 같은 커스텀 객체를 탐지하고,
탐지된 객체의 위치를 강화학습 환경에 넘겨줘서 PPO 에이전트가 반응하도록 연결하는 전체 코드를 작성한다.
from ultralytics import YOLO
import cv2
from stable_baselines3 import PPO
from tracking_env import MovingTargetEnv
model = YOLO("runs/detect/train7/weights/best.pt")
agent = PPO.load("ppo_tracking_agent")
env = MovingTargetEnv()
obs, _ = env.reset()
cap = cv2.VideoCapture(0)
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
results = model(frame)
result = results[0]
if result.boxes:
best_box = max(result.boxes, key=lambda b: b.conf[0].item())
x1, y1, x2, y2 = map(int, best_box.xyxy[0])
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
env.update_target_position((cx, cy))
label = f"{model.names[int(best_box.cls[0].item())]} {best_box.conf[0].item():.2f}"
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
cv2.putText(frame, label, (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
action, _ = agent.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
print(f"Action: {action}, Reward: {reward}, Target: {env.target_position}")
cv2.imshow("YOLO + RL", frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if terminated or truncated:
obs, _ = env.reset()
cap.release()
cv2.destroyAllWindows()
5. YOLO + 강화학습 + DeepSORT를 결합
동일한 객체를 지속적으로 추적하고, 에이전트가 추적 대상에 맞춰 반응하는 시스템으로 확장.
우선은 "1개"의 대상만 추적하도록 설정한다.
from ultralytics import YOLO
import cv2
from stable_baselines3 import PPO
from tracking_env import MovingTargetEnv
from deep_sort_realtime.deepsort_tracker import DeepSort
model = YOLO("runs/detect/train7/weights/best.pt")
agent = PPO.load("ppo_tracking_agent")
env = MovingTargetEnv()
obs, _ = env.reset()
tracker = DeepSort(max_age=30)
cap = cv2.VideoCapture(0)
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
results = model(frame)
result = results[0]
detections = []
for box in result.boxes:
x1, y1, x2, y2 = map(int, box.xyxy[0])
conf = box.conf[0].item()
cls = int(box.cls[0].item())
detections.append(([x1, y1, x2 - x1, y2 - y1], conf, cls))
tracks = tracker.update_tracks(detections, frame=frame)
for track in tracks:
if not track.is_confirmed():
continue
x1, y1, x2, y2 = track.to_tlbr()
track_id = track.track_id
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
env.update_target_position((cx, cy))
label = f"ID:{track_id}"
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
cv2.putText(frame, label, (int(x1), int(y1) - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
break # 하나의 추적 대상만 추적 (에이전트가 따라갈 기준)
action, _ = agent.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
print(f"Action: {action}, Reward: {reward}, Target: {env.target_position}")
cv2.imshow("YOLO + DeepSORT + RL", frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
if terminated or truncated:
obs, _ = env.reset()
cap.release()
cv2.destroyAllWindows()
'Computer Engineering > 강화학습' 카테고리의 다른 글
[강화학습] Masked Reinforcement Learning (0) | 2025.03.25 |
---|---|
[강화학습] YOLO + 강화학습 환경 연동 (0) | 2025.03.22 |
[강화학습] PPO 알고리즘 (0) | 2025.03.22 |
[YOLO] YOLOv8 실시간 물체 탐지 실습 (0) | 2025.03.20 |
[YOLO] YOLOv8 + DeepSORT 이동물체추적 (1) | 2025.03.20 |