[강화학습] roboflow 실습

2025. 3. 23. 18:42Computer Engineering/강화학습

1. Roboflow Universe 에서 데이터셋을 가져온다. 

나 같은 경우 아래와 같은 lane 데이터셋을 다운로드 받았다. 

https://universe.roboflow.com/duyen/lane-unzou/fork

 

lane Instance Segmentation Dataset by duyen

352 open source lane images. lane dataset by duyen

universe.roboflow.com

download format에서 본인이 필요로 하는 코드를 다운 받으면 된다. 나는 YOLOv8을 선택했다. 

 

그렇게 다운로드 받은 코드를 실행하게 되면 

data.yaml 파일이 들어 있는 폴더가 만들어지게 되는데, 이 폴더를 꼭 'datasets'  폴더 경로에 넣어야 한다.

 

2. YOLOv8로 학습하기

터미널에 아래 코드를 입력한다. 

yolo detect train model=yolov8n.pt data=your-dataset/data.yaml epochs=50 imgsz=640

 

위와 같이 학습이 진행되기 때문에 시간이 생각보다 오래 걸릴 수 있다. 

 

3. 학습이 끝난 후에는 자동으로 해당하는 경로에 모델이 저장된다. 

→ 커스텀 객체 탐지 모델

 

4. 저장된 모델 경로를 강화학습 환경에 적용한다.

YOLO가 실시간으로 lane, stop, intersection 같은 커스텀 객체를 탐지하고,
탐지된 객체의 위치를 강화학습 환경에 넘겨줘서 PPO 에이전트가 반응하도록 연결하는 전체 코드를 작성한다. 

from ultralytics import YOLO
import cv2
from stable_baselines3 import PPO
from tracking_env import MovingTargetEnv

model = YOLO("runs/detect/train7/weights/best.pt")
agent = PPO.load("ppo_tracking_agent")
env = MovingTargetEnv()
obs, _ = env.reset()

cap = cv2.VideoCapture(0)

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    results = model(frame)
    result = results[0]

    if result.boxes:
        best_box = max(result.boxes, key=lambda b: b.conf[0].item())
        x1, y1, x2, y2 = map(int, best_box.xyxy[0])
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

        env.update_target_position((cx, cy))

        label = f"{model.names[int(best_box.cls[0].item())]} {best_box.conf[0].item():.2f}"
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, label, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    action, _ = agent.predict(obs)
    obs, reward, terminated, truncated, _ = env.step(action)

    print(f"Action: {action}, Reward: {reward}, Target: {env.target_position}")

    cv2.imshow("YOLO + RL", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

    if terminated or truncated:
        obs, _ = env.reset()

cap.release()
cv2.destroyAllWindows()

 

 

5. YOLO + 강화학습 + DeepSORT를 결합

동일한 객체를 지속적으로 추적하고, 에이전트가 추적 대상에 맞춰 반응하는 시스템으로 확장.

우선은 "1개"의 대상만 추적하도록 설정한다. 

from ultralytics import YOLO
import cv2
from stable_baselines3 import PPO
from tracking_env import MovingTargetEnv
from deep_sort_realtime.deepsort_tracker import DeepSort

model = YOLO("runs/detect/train7/weights/best.pt")
agent = PPO.load("ppo_tracking_agent")
env = MovingTargetEnv()
obs, _ = env.reset()

tracker = DeepSort(max_age=30)
cap = cv2.VideoCapture(0)

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    results = model(frame)
    result = results[0]

    detections = []
    for box in result.boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        conf = box.conf[0].item()
        cls = int(box.cls[0].item())
        detections.append(([x1, y1, x2 - x1, y2 - y1], conf, cls))

    tracks = tracker.update_tracks(detections, frame=frame)

    for track in tracks:
        if not track.is_confirmed():
            continue
        x1, y1, x2, y2 = track.to_tlbr()
        track_id = track.track_id
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2

        env.update_target_position((cx, cy))

        label = f"ID:{track_id}"
        cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
        cv2.putText(frame, label, (int(x1), int(y1) - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0), 2)
        break  # 하나의 추적 대상만 추적 (에이전트가 따라갈 기준)

    action, _ = agent.predict(obs)
    obs, reward, terminated, truncated, _ = env.step(action)

    print(f"Action: {action}, Reward: {reward}, Target: {env.target_position}")

    cv2.imshow("YOLO + DeepSORT + RL", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

    if terminated or truncated:
        obs, _ = env.reset()

cap.release()
cv2.destroyAllWindows()