강화학습/파이썬과 케라스로 배우는 강화학습(스터디)

[강화학습] 15 - 브레이크아웃과 CNN

고집호랑이 2023. 1. 27. 07:30

개요

지금까지 저희는 그리드월드와 카트폴과 같은 간단한 예제에 강화학습 알고리즘을 적용해봤습니다.

 

이 두 예제는 간단하며 인공신경망의 입력으로 사용된 상태 공간도 작았습니다. 하지만 다소 복잡하고 상태 공간이 큰 게임 화면으로부터도 에이전트는 잘 학습할 수 있을까요?  

 

이번 포스팅부터 저희는 아타리 사의 브레이크아웃이라는 게임에서 DQN과 A3C 알고리즘을 적용해보면서 에이전트가 게임화면으로부터 어떻게 학습하는지 알아보도록 하겠습니다.

 

아타리: 브레이크아웃

브레이크아웃은 아타리라는 미국 게임 회사에서 만든 벽돌 깨기 게임입니다.

 

이 고전 게임은 2013년에 다시 유명해지게 되는데, 바로 알파고를 만든 회사로 알려진 딥마인드에서 강화학습을 통해 이 브레이크아웃 게임을 학습시켰기 때문입니다.

 

이때 딥마인드가 소개하고 학습에 사용했던 강화학습 알고리즘이 바로 DQN 알고리즘입니다. 

 

오픈에이아이 짐에는 브레이크아웃도 환경으로서 제공하기 때문에 저희는 오픈에이아이 짐의 브레이크아웃에 강화학습을 적용해볼 것입니다. 

 

아타리의 브레이크아웃

 

브레이크아웃에 강화학습을 적용하기 위해서 먼저 MDP를 살펴볼 필요가 있습니다. 

 

상태: 브레이크아웃에서 상태는 2차원 RGB 픽셀 데이터로 이루어진 게임 화면이며, 총 4개의 화면을 연속으로 입력 받습니다.

 

행동: 에이전트가 할 수 있는 행동으로는 제자리, 왼쪽, 오른쪽으로 3가지가 있습니다.

 

보상: 에이전트는 벽돌이 하나씩 깨질 때마다 (+)보상을 받으며 좀 뒤 쪽의 벽돌을 깰수록 높은 보상을 받습니다. 또한 아무것도 깨지 못한 평소에는 0의 보상을 받고 공을 떨어뜨려 목숨을 잃었을 때는 (-1)의 보상을 받습니다.

 

당연하게도 에이전트는 보상을 최대화하는 것이 목표입니다. 한 에피소드에서 에이전트는 총 5개의 목숨을 가지고 있으며 이를 모두 잃으면 게임 오버가 됩니다. 

 

브레이크아웃에서 강화학습을 통해서 학습되는 것은 심층신경망인데, 그중에서도 컨볼루션 신경망이라는 것입니다.

 

컨볼루션 신경망(CNN)

브레이크아웃에서의 강화학습은 게임화면으로 학습하기 때문에 인공신경망의 입력이 게임화면이 되어야합니다.

 

이때 입력으로 들어가는 게임화면의 크기는 (가로 픽셀 수) × (세로 픽셀 수)  × 3(RGB)가 됩니다.

 

2차원 숫자의 배열이 3차원으로 겹쳐진 형태인 게임화면을 1차원 특정벡터로 나열하여 인공신경망의 입력으로 들어간다면 인공신경망의 노드와 가중치의 수가 너무 많아집니다.

 

브레이크아웃의 이미지 크기는 210 × 160 × 3 = 100800이고 이전에 사용했던 인공신경망처럼 완전 연결 층(※ 한 층의 모든 노드가 다음 층의 모든 노드와 연결된 상태)이라면 은닉층에 있는 하나의 노드마다 100800개의 가중치값이 필요하게 되는 것이죠.

 

이렇게 인공신경망의 크기가 매우 커지면 학습이 무지막지하게 오래 걸리게 됩니다.

 

이 문제를 해결하기 위해서 실제로 사람이 시작정보를 처리하는 방법을 이용하게 됩니다.

 

사람의 시신경은 전체 시각입력 중 특정 입력에 대해서만 반응하는데, 이를 이용하여 은닉층의 노드들도 인공신경망의 입력층의 일부 노드에 대해서만 반응하도록 구조를 만들었습니다.

 

이를 구현한 것이 컨볼루션 필터입니다. 이미지에 적용하는 필터는 이미지의 노이즈를 없애거나 어떤 특징을 강조할 때 사용하는 것으로 대부분 박스 형태의 필터를 사용합니다.

 

박스의 필터는 아래 그림처럼 이미지의 왼쪽 위부터 오른쪽 아래까지 이동하면서 이미지의 각 픽셀 값이 주변 픽셀값과 필터를 통해 연산됩니다. 이를 컨볼루션 연산이라고합니다.

 

출처 - http://deeplearning.stanford.edu/wiki/index.php/Feature_extraction_using_convolution

 

필터의 종류에는 아래 그림과 같이 선명도를 딸어뜨리는 필터, 이미지를 왼쪽으로 1 픽셀씩 움직이는 필터 등 종류가 엄청 많습니다.

 

필터의 종류
출처 - https://m.blog.naver.com/PostView.naver?isHttpsRedirect=true&blogId=framkang&logNo=220561249726

 

이미지에 어떤 필터를 사용하느냐에 따라 전혀 다른 정보를 전달하기 때문에 학습시킬 때 수많은 필터 중 어떤 필터를 사용할지 정하는 것은 학습의 성능과 성공 유무를 결정하는 중요한 일이자 어려운 일이 되겠죠.

 

이 과정을 자동으로 만들어주는 것이 바로 컨볼루션 신경망입니다. 컨볼루션 신경망을 이용하면 전문성을 필요로 하는 '특징 추출'을 스스로 하면서 바로 입력에 대한 출력이 나옵니다.

 

컨볼루션 신경망은 보통 여러 층의 컨볼루션 층을 가지고 있습니다. 각 층에는 게임화면으로부터 특징을 추출하는 필터들이 존재하죠. 

 

아래의 오른쪽 그림이 첫 번째 층에 해당하는 브레이크아웃의 학습된 32개의 필터입니다.

 

DQN의 CNN과 학습된 필터

브레이크아웃의 게임 화면이 입력으로 들어오면 필터와 이미지의 컨볼루션 연산을 통해서 새로운 이미지가 생성됩니다. 이 이미지의 각 픽셀 값은 활성함수를 통과하고 다시 다음 층의 컨볼루션 필터와 컨볼루션 연산을 수행합니다.

 

이후 마지막 컨볼루션 층까지 통과하고 노드들은 일렬로 퍼지는데 이 과정을 플랫이라고 합니다. 이렇게 펼쳐진 노드는 출력층을 통과해 행동의 개수대로 큐함수의 값을 출력합니다. 

 

이렇게 컨볼루션 신경망을 이용해 큐함수를 근사하는 이 네트워크를 딥-큐네트워크(DQN)이라고 합니다.

 

브레이크아웃의 컨볼루션 신경망

class DQN(tf.keras.Model):
    def __init__(self, action_size, state_size):
        super(DQN, self).__init__()
        self.conv1 = Conv2D(32, (8, 8), strides=(4, 4), activation='relu',
                            input_shape=state_size)
        self.conv2 = Conv2D(64, (4, 4), strides=(2, 2), activation='relu')
        self.conv3 = Conv2D(64, (3, 3), strides=(1, 1), activation='relu')
        self.flatten = Flatten()
        self.fc = Dense(512, activation='relu')
        self.fc_out = Dense(action_size)

    def call(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.flatten(x)
        x = self.fc(x)
        q = self.fc_out(x)
        return q

브레이크아웃에 사용되는 컨볼루션 신경망을 생성하는 코드는 위와 같습니다. 여기에서는 총 3개의 컨볼루션 층을 생성하였는데, 이때 저희는 필터의 개수, 필터의 크기, 필터가 이동하는 폭, 활성함수를 설정해야 합니다.

 

self.conv3 = Conv2D(64, (3, 3), strides=(1, 1), activation='relu')

이 코드는 3번째 컨볼루션 층을 생성하는 코드인데 <64>는 필터의 개수, <(3,3)>은 필터의 크기, <strides = (1,1)>은 필터가 이동하는 폭, relu는 사용하는 활성함수를 나타냅니다. 

 

정리하자면 (3,3) 크기의 필터 64개가, (1,1)씩 이동하면서 이미지와 컨볼루션 연산을 한 후 ReLU 활성함수를 통과한다는 뜻입니다.

 

3개의 컨볼루션 층을 통과하면 Flatten()이라는 함수가 노드들을 플랫시켜주고, 이후 출력층을 통과하여 action_size만큼 노드를 생성하여 (좌, 우, 정지)에 해당하는 큐함수의 값을 출력합니다.

 

여기서 컨볼루션 신경망은 컨볼루션 연산만 진행하였지만, 본래 컨볼루션 신경망은 컨볼루션 연산 외에도 padding과 pooling 과정도 존재합니다. 이에 대해서는 다음에 기회가 되면 설명하도록 하겠습니다. 

 

저희는 이전에 카트폴에서도 DQN 알고리즘을 적용하였는데, 카트폴은 화면으로 학습하는 것이 아니기 때문에 컨볼루션 신경망은 사용되지 않았습니다.  

 

다음 포스팅에서는 카트폴에서의 DQN 알고리즘에서 추가로 필요한 사항과 브레이크 아웃에서의 DQN 알고리즘 코드를 살펴보도록 하겠습니다. 읽어주셔서 감사합니다~!

 

 

http://www.yes24.com/Product/Goods/44136413

 

파이썬과 케라스로 배우는 강화학습 - YES24

“강화학습을 쉽게 이해하고 코드로 구현하기”강화학습의 기초부터 최근 알고리즘까지 친절하게 설명한다!‘알파고’로부터 받은 신선한 충격으로 많은 사람들이 강화학습에 관심을 가지기

www.yes24.com

※ 이 글은 위의 책 내용을 바탕으로 작성한 글입니다.