[강화학습] Masked Reinforcement Learning
2025. 3. 25. 11:25ㆍComputer Engineering/강화학습
1.Masking 기법
강화학습에서는 에이전트가 가능한 모든 action 중 하나를 선택해 환경 env와 상호작용하고, reward를 최대화하도록 학습한다.
그런데, 어떤 상황에서는 특정 행동이 물리적으로 불가능하거나, 규칙을 어기거나, 비효율적으로 위험할 수 있다.
이러한 경우에도 에이전트가 탐험을 위해 그러한 행동을 시도할 수 있는데, 이러한 행동은 학습 속도를 느리게 하고, 학습 실패로까지 이어질 수 있다.
→ Masked RL은 이런 불필요한 행동을 미리 제거해서 학습을 더 효율적이고 안정적으로 만들 수 있다.
행동공간이 매우 크거나 복잡할 때, 안전이나 물리적 제약이 존재할 때, 학습 속도가 느릴때, 행동마가 계산 비용이 클때 마스킹 기법을 사용하면 좋다.
2. 핵심 IDEA
Action Mask : 상태 s에서 가능한 행동 중 불가능한 행동은 확률을 0으로 masking하고, 나머지만 학습에 반영한다.
ex) Discrete Action Space: 행동공간 A = [왼, 오, 위, 아래] , 현재 상태에서 '위'로 가면 벽에 부딪힌다
→ Mask: [1, 1, 0, 1]
위의 예시처럼 masking이 적용된 후, 에이전트는 확률 분포에서 '위'방향을 아예 선택하지 않는다.
3. 동작 흐름
현재상태 s
→ 정책 네트워크: 원래는 모든 행동의 확률 분포 출력
→ Action Masking
→ 정규화된 행동 확률: 가능한 행동 중 하나 선택
→ 행동 a 수행
→ 다음 상태, 보상
4. 수식
- 일반적인 정책 확률 분포 : π(a∣s)=softmax(fθ(s))
- 마스킹 적용 : (단, m(a)∈{0,1}, 마스크된 행동은 m(a)=0 → 확률 0)
5. 실습 목표
- 3차선 도로 환경에서 에이전트가 차선을 유지하며 주행
- PPO 알고리즘을 사용해서 정책을 학습
- 그리고 마스킹을 적용한 PPO(Masked PPO) 와 기존 PPO 를 비교 실험
6. 코드구성
- lane_env.py
class LaneEnv(gym.Env):
def __init__(self):
...
self.observation_space = spaces.Box(low=0, high=2, shape=(2,), dtype=np.int32)
self.action_space = spaces.Discrete(5) # 행동 5가지 (좌,우,감속,가속,유지)
def get_action_mask(self): # Masking
mask = np.ones(5, dtype=np.int32)
if self.lane == 0:
mask[0] = 0 # 좌측 이동 불가
if self.lane == 2:
mask[1] = 0 # 우측 이동 불가
if self.speed == 2:
mask[2] = 0 # 가속 불가
if self.speed == 0:
mask[3] = 0 # 감속 불가
return mask
- 정책 네트워크.py
class MaskedPolicy(nn.Module):
def __init__(self, input_dim=2, output_dim=5):
...
self.fc = nn.Sequential(
nn.Linear(input_dim, 64),
nn.ReLU(),
nn.Linear(64, output_dim)
)
def forward(self, x, mask=None):
logits = self.fc(x) # 원래 softmax 전에 나오는 raw score
if mask is not None:
logits = logits.masked_fill(mask == 0, float('-inf'))
return F.softmax(logits, dim=-1)
마스크된 행동은 확률 0 처리되니까, 절대 선택되지 않는다.
- train_masked.py
mask = torch.tensor(env.get_action_mask(), dtype=torch.bool)
# Mask를 policy에 직접 넣어줘서, softmax에서 불가능한 행동은 자동 제외된 확률 분포를 사용한다.
probs = policy(state, mask=mask)
dist = torch.distributions.Categorical(probs)
action = dist.sample()
- train.py
probs = policy(state) # mask 없이 그냥 행동 확률 출력
'Computer Engineering > 강화학습' 카테고리의 다른 글
[강화학습] roboflow 실습 (0) | 2025.03.23 |
---|---|
[강화학습] YOLO + 강화학습 환경 연동 (0) | 2025.03.22 |
[강화학습] PPO 알고리즘 (0) | 2025.03.22 |
[YOLO] YOLOv8 실시간 물체 탐지 실습 (0) | 2025.03.20 |
[YOLO] YOLOv8 + DeepSORT 이동물체추적 (1) | 2025.03.20 |