[강화학습] YOLO + 강화학습 환경 연동
2025. 3. 22. 23:44ㆍComputer Engineering/강화학습
목표
YOLOv8로 실시간으로 물체(ex. 사람)를 탐지하고,
해당 물체의 위치를 강화학습 환경으로 넘겨서,
AI가 실제 움직임에 맞춰 따라가거나 회피하도록 반응하게 한다.
1. YOLO가 실시간으로 프레임을 분석하여 탐지된 사람의 좌표(x, y)를 추출한다.
2. 그 좌표를 강화학습 환경의 target_position으로 전달한다.
3. PPO 에이전트가 해당 좌표를 목표로 삼고 행동을 결정한다.
1. YOLO로 사람 위치 추출
from ultralytics import YOLO
import cv2
model = YOLO("yolov8n.pt")
cap = cv2.VideoCapture(0)
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
results = model(frame)
result = results[0]
# 사람 클래스만 필터링
people_boxes = [box for box in result.boxes if int(box.cls[0].item()) == 0] # 0: person
if people_boxes:
# 신뢰도 가장 높은 사람 선택
best_box = max(people_boxes, key=lambda b: b.conf[0].item())
x1, y1, x2, y2 = map(int, best_box.xyxy[0])
person_center = ((x1 + x2) // 2, (y1 + y2) // 2)
print(f"탐지된 사람 중심 좌표: {person_center}")
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
2. 강화학습 환경으로 좌표 전달
import gymnasium as gym
import numpy as np
from gymnasium import spaces
class MovingTargetEnv(gym.Env):
def __init__(self):
super().__init__()
self.area_size = 640 # 프레임 크기 기준 (640x480이라면 가로 기준)
self.agent_position = np.array([320, 240]) # 에이전트 시작 위치 (중앙)
self.target_position = np.array([100, 100]) # 타겟 초기 위치
# 행동: 0=왼쪽, 1=오른쪽, 2=위쪽, 3=아래쪽
self.action_space = spaces.Discrete(4)
# 상태: [에이전트 x, y, 타겟 x, y]
self.observation_space = spaces.Box(
low=0,
high=self.area_size,
shape=(4,),
dtype=np.float32
)
def step(self, action):
# 에이전트 이동
if action == 0:
self.agent_position[0] -= 10
elif action == 1:
self.agent_position[0] += 10
elif action == 2:
self.agent_position[1] -= 10
elif action == 3:
self.agent_position[1] += 10
# 범위 제한
self.agent_position = np.clip(self.agent_position, 0, self.area_size)
# 보상: 타겟과 가까워지면 보상 증가
distance = np.linalg.norm(self.agent_position - self.target_position)
reward = -distance
# 도달 여부 판단 (오차 범위 20px 이내)
done = distance < 20
terminated = done
truncated = False
obs = np.concatenate([self.agent_position, self.target_position])
return obs, reward, terminated, truncated, {}
def reset(self, seed=None, options=None):
self.agent_position = np.array([320, 240]) # 중앙으로 초기화
# 타겟 위치는 YOLO로부터 실시간으로 전달받음 (기본은 그대로 둠)
obs = np.concatenate([self.agent_position, self.target_position])
return obs, {}
def update_target_position(self, new_position):
"""YOLO 결과를 받아 타겟 위치 갱신"""
self.target_position = np.clip(np.array(new_position), 0, self.area_size)
3. 에이전트 학습
from stable_baselines3 import PPO
from tracking_env_people import MovingTargetEnv
env = MovingTargetEnv()
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=50000) # 5만법 학습
model.save("ppo_tracking_people_agent")
from stable_baselines3 import PPO
from tracking_env_people import MovingTargetEnv
import matplotlib.pyplot as plt
import time
model = PPO.load("ppo_tracking_people_agent")
env = MovingTargetEnv()
obs, _ = env.reset()
for _ in range(50): # 50번 움직임
action, _ = model.predict(obs) # 강화학습 모델이 행동 선택
obs, reward, done, truncated, _ = env.step(action)
print(f"Action: {action}, Reward: {reward}")
# 시각화
plt.clf()
plt.xlim(0, 20)
plt.ylim(0, 20)
plt.plot(env.agent_position[0], env.agent_position[1], 'bo', label='Agent') # 파란 점
plt.plot(env.target_position[0], env.target_position[1], 'ro', label='Target') # 빨간 점
plt.legend()
plt.pause(0.2)
if done:
print("목표 도달")
break
4. YOLO 통합 후 실시간 반응 확인
카메라에 사람이 들어오면 AI가 그 좌표를 타겟으로 삼고 반응한다.
from ultralytics import YOLO
import cv2
from stable_baselines3 import PPO
from tracking_env_people import MovingTargetEnv
# YOLO 모델 로드
model = YOLO("yolov8n.pt")
agent = PPO.load("ppo_tracking_people_agent")
env = MovingTargetEnv()
obs, _ = env.reset()
cap = cv2.VideoCapture(0)
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
results = model(frame)
result = results[0]
people_boxes = [box for box in result.boxes if int(box.cls[0].item()) == 0]
if people_boxes:
best_box = max(people_boxes, key=lambda b: b.conf[0].item())
x1, y1, x2, y2 = map(int, best_box.xyxy[0])
cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
# 타겟 위치 갱신
env.update_target_position((cx, cy))
# 화면에 바운딩 박스 그리기
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
label = f"Target: ({cx}, {cy})"
cv2.putText(frame, label, (x1, y1 - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
# 에이전트 행동 선택
action, _ = agent.predict(obs)
obs, reward, terminated, truncated, _ = env.step(action)
print(f"Action: {action}, Reward: {reward}, Target: {env.target_position}")
cv2.imshow("YOLO + RL", frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cap.release()
cv2.destroyAllWindows()
최종 결과
해당 파일 실행 후 카메라 창이 뜨고, 탐지된 사람 위에 녹색 바운딩 박스 + 좌표 표시가 된다.
PPO 에이전트는 그 타겟 좌표를 따라 반응하고, 터미널에는 Action, Reward, Target이 출력된다.
'Computer Engineering > 강화학습' 카테고리의 다른 글
[강화학습] Masked Reinforcement Learning (0) | 2025.03.25 |
---|---|
[강화학습] roboflow 실습 (0) | 2025.03.23 |
[강화학습] PPO 알고리즘 (0) | 2025.03.22 |
[YOLO] YOLOv8 실시간 물체 탐지 실습 (0) | 2025.03.20 |
[YOLO] YOLOv8 + DeepSORT 이동물체추적 (1) | 2025.03.20 |