[2021 KAKAO Blind Recruitment] Q5. 광고 삽입 (C++, Python, Java)

문제 링크

https://programmers.co.kr/learn/courses/30/lessons/72414

 

예상 난이도

G2

 

알고리즘 분류

투 포인터 or Prefix Sum

 

풀이

투 포인터 or Prefix Sum을 이용하는 전형적인 문제이나 투 포인터로 접근했다면 구현 난이도가 다소 높아서 어려웠을 문제입니다. 우선 등장 가능한 시각의 수는 360,000개이고 시각을 모두 초로 변환한 후 정수 범위에서 생각을 하면 편합니다. 참고로 저는 처음 문제를 풀 때에는 시각의 수가 360,000개인줄 모르고 바로 투 포인터로 해결했습니다.

 

일단 Prefix Sum 풀이 먼저 살펴보면, 특정 초에 시청중인 사람의 수를 계산하는 작업이 필요하게 됩니다. 

 

이렇게 구간이 들어올 때 마다 대응되는 모든 영역에 +1을 하면 쉽게 계산이 가능합니다.

 

그런데 시간복잡도를 생각해보면 가능한 시간은 360,000개이므로 O(360000N)이 되고, 그렇기 때문에 이렇게 구현하면 시간 초과가 발생해야 합니다. 원래는 시간 초과가 나야하는데 데이터가 약한지 현시점을 기준으로 C++, Java는 이렇게 구현해도 통과가 가능합니다. 다만 Python은 택도 없습니다.

 

코딩테스트를 치는 입장에서야 잘못된 시간복잡도로 접근해도 풀리면 매우 이득이지만 공부 하는 입장에서는 문제를 잘못된 풀이로 기억하는 불상사가 생길 수 있다보니 이런 일은 최대한 없으면 좋은데 아쉽네요. 아무튼 그러면 어떻게 이 상황을 해결하냐 생각할 때, 바로 변화량을 기록하는 방법으로 시간복잡도를 떨굴 수 있습니다.

 

아까와 달리 이번에는 시작 지점에 1을, 끝 지점에 -1을 쓰기만 합니다. 이게 무슨 의미인지는 조금 있다가 같이 봅시다.

 

이렇게 4개의 시청기록을 의미했는데 4초에 대응된 2라는 값은 대체 무슨 의미를 가지는 값일까요?

 

저 값은 4초일 때 시청중인 사람의 수가 3초일 때 보다 2명이 더 늘었음을 알려주는 값입니다. 즉 우리는 변화량을 잘 기록해둔거죠.

 

이 변화량을 바탕으로 우리가 얻고싶어했던 시청중인 사람의 수를 계산해낼 수 있습니다.

 

이렇게 각 log마다 최대 360,000칸의 값을 1 증가시키는 대신 2칸의 값에만 +1, -1을 기록해 O(N)에 수행 가능합니다.

 

그 다음으로 이제 누적 재생시간이 가장 많은 곳을 계산해야 하는데 위처럼 구간의 합을 직접 다 구하면 O(3600002)여서 바로 시간 초과 직행입니다.

 

대신 이전 구간에서의 값을 이용하면 O(360000)에 계산 가능합니다. 이렇게 Prefix Sum 풀이를 살펴보았습니다.

 

두 번째 방법은 투 포인터를 이용하는 방법입니다. 사실 Prefix sum을 쓰는게 이 문제에서는 더 간단하긴 합니다. 그런데 투 포인터 방법도 알아둘 필요가 있어서 이렇게 소개를 드립니다. 투 포인터 방법에서는 아까와 달리 큰 배열을 잡지 않습니다. 여기 있는 수직선은 오로지 설명의 편의를 위해 제시되는 수직선입니다.

 

이렇게 들어온 각 시청 기록에 대해 시작과 끝을 따로 Event로 관리합니다. 저 pair에서 앞은 시간, 뒤는 시작일 경우 1 / 끝일 경우 -1입니다.

 

그 다음 시간 순으로 Event를 정렬합니다. 여기까지가 전처리입니다.

 

다음으로 누적 재생시간이 가장 많은 곳을 계산해야 하는데, Event가 있는 곳이 바로 그 시간에 시청중인 사람 수에 변화가 있는 곳입니다. 일단 처음 구간은 저렇게 계산을 합니다.

 

처음에 0 to 6초 구간을 본 뒤 1 to 7초 구간을 보는게 아니라, 저 구간 막대의 왼쪽 끝 혹은 오른쪽 끝이 event에 걸릴 때 까지 이동을 합니다. 그렇게 이동을 한 후 이전과 비교했을 때 제거되는 부분(빨간색 영역), 추가되는 부분(파란색 영역)을 적절하게 계산해줍니다.

 

투 포인터를 사용하면 이와 같이 시간 복잡도에서 전체 시간의 개수 360,000이 빠진다는 큰 이점이 있습니다. 즉 만약 전체 시간이 360,000까지가 아니라 1,000,000,000까지와 같이 굉장히 컸다면 Prefix Sum 풀이 대신(아예 불가능한건 아니고 좌표압축을 이용하면 가능하긴 합니다) 투 포인터로 풀어내야 합니다.  다만 투 포인터와 구간을 다루는 문제에 익숙하지 않다면 풀이가 잘 이해가 안가거나, 코드를 짜는데 큰 어려움을 겪을거라 우선은 Prefix Sum 풀이를 먼저 살펴보시고 투 포인터를 충분히 숙달한 뒤에 다시 투 포인터로 도전해보시면 좋을 것 같습니다. 

 

코드(C++, Prefix Sum)

#include <bits/stdc++.h>
using namespace std;

int s2i(string s){
    return stoi(s.substr(0, 2))*3600 + stoi(s.substr(3, 2))*60 + stoi(s.substr(6, 2));
}

string numzfill(int x){
    if(x < 10) return "0"+to_string(x);
    return to_string(x);
}

string i2s(int t){
    string ret = numzfill(t/3600) + ":";
    t %= 3600;
    ret += numzfill(t/60)+":";
    t %= 60;
    return ret + numzfill(t);
}

string solution(string play_time, string adv_time, vector<string> logs) {
    int pt = s2i(play_time), at = s2i(adv_time);
    // A. 특정 초에 시청중인 사람의 수 계산
    int d[360001] = {};
    for(auto l : logs){
        int st = s2i(l.substr(0, 8)), en = s2i(l.substr(9, 8));
        d[st]++; d[en]--;
    }
    for(int i = 1; i <= 360000; i++) d[i] += d[i-1];
    // B. 누적 재생시간이 가장 많은 곳 계산
    long long mxval = 0, curval = 0;
    int mxtime = 0;
    for(int i = 0; i < at; i++) curval += d[i];
    mxval = curval;
    for(int i = 1; i <= 360000 - at; i++){
        curval = curval - d[i-1] + d[i+at-1];
        if(curval > mxval){
            mxval = curval;
            mxtime = i;
        }
    }
    return i2s(mxtime);
}

 

코드(C++, 투 포인터)

#include <bits/stdc++.h>
using namespace std;
using pii = pair<int,int>;

int s2i(string s){
    return stoi(s.substr(0, 2))*3600 + stoi(s.substr(3, 2))*60 + stoi(s.substr(6, 2));
}

string numzfill(int x){
    if(x < 10) return "0"+to_string(x);
    return to_string(x);
}

string i2s(int t){
    string ret = numzfill(t/3600) + ":";
    t %= 3600;
    ret += numzfill(t/60)+":";
    t %= 60;
    return ret + numzfill(t);
}

string solution(string play_time, string adv_time, vector<string> logs) {
    int pt = s2i(play_time), at = s2i(adv_time);
    vector<pii> event;
    // A. 전처리
    for(auto l : logs){
        int st = s2i(l.substr(0, 8)), en = s2i(l.substr(9, 8));
        event.push_back({st, 1});
        event.push_back({en, -1});                
    }
    event.push_back({0, 0});
    sort(event.begin(), event.end());
    // B. 누적 재생시간이 가장 많은 곳 계산
    // cnt1 : 시작 구간에서의 시청중인 사람의 수
    // cnt2 : 끝 구간에서의 시청중인 사람의 수
    int idx1 = 0, idx2 = 0, cnt1 = 0, cnt2 = 0;
    long long curval = 0, mxval = 0;
    int curtime = 0, mxtime = 0;
    while(idx2 + 1 < event.size() && event[idx2+1].first <= at){
        curval += (event[idx2+1].first-event[idx2].first) * cnt2;
        cnt2 += event[idx2+1].second;
        idx2++;
    }
    curval += (at - event[idx2].first) * cnt2;
    mxval = curval;
    
    while(curtime <= pt-at && idx2 + 1 < event.size()){
        int delta1 = event[idx1+1].first - curtime;
        int delta2 = event[idx2+1].first - (curtime + at);
        if(delta1 <= delta2){ // 시작 구간이 다음 event에 더 가까운 경우
            curval = curval + 1ll * (cnt2 - cnt1) * delta1;
            cnt1 += event[idx1+1].second;
            idx1++;
            curtime += delta1;
        }
        else{
            curval = curval + 1ll * (cnt2 - cnt1) * delta2;
            cnt2 += event[idx2+1].second;
            idx2++;
            curtime += delta2;
        }
        if(curval > mxval){
            mxval = curval;
            mxtime = curtime;
        }
    }
    return i2s(mxtime);
}

 

코드(Python, Prefix Sum)

def s2i(s):
    z = s.split(':')
    return int(z[0])*3600+int(z[1])*60+int(z[2])

def i2s(t):
    ret = ''
    ret += str(t//3600).zfill(2)+':'
    t %= 3600
    ret += str(t//60).zfill(2)+':'
    t %= 60
    ret += str(t).zfill(2)
    return ret

def solution(play_time, adv_time, logs):
    pt, at = s2i(play_time), s2i(adv_time)
    # A. 특정 초에 시청중인 사람의 수 계산
    d = [0]*360001
    for l in logs:
        st, en = map(s2i, l.split('-'))
        d[st] += 1
        d[en] -= 1
    for i in range(1, 360001):
        d[i] += d[i-1]
    # B. 누적 재생시간이 가장 많은 곳 계산
    mxval, mxtime = sum(d[:at]), 0
    curval = mxval
    for i in range(1, 360001-at):
        curval = curval - d[i-1] + d[i+at-1]
        if curval > mxval:
            mxval = curval
            mxtime = i
    return i2s(mxtime)

코드(Python, 투 포인터)

def s2i(s):
    z = s.split(':')
    return int(z[0])*3600+int(z[1])*60+int(z[2])

def i2s(t):
    ret = ''
    ret += str(t//3600).zfill(2)+':'
    t %= 3600
    ret += str(t//60).zfill(2)+':'
    t %= 60
    ret += str(t).zfill(2)
    return ret

def solution(play_time, adv_time, logs):
    event = []
    pt, at = s2i(play_time), s2i(adv_time)
    # A. 전처리
    for l in logs:
        st,en = map(s2i, l.split('-'))
        event.append((st,1))
        event.append((en,-1))
    event.append((0, 0));
    event.sort()
    # B. 누적 재생시간이 가장 많은 곳 계산
    # cnt1 : 시작 구간에서의 시청중인 사람의 수
    # cnt2 : 끝 구간에서의 시청중인 사람의 수
    idx1, idx2, cnt1, cnt2 = 0, 0, 0, 0
    curtime, curval = 0, 0
    while idx2 < len(event) - 1 and event[idx2+1][0] <= at:
        curval += (event[idx2+1][0]-event[idx2][0]) * cnt2
        cnt2 += event[idx2+1][1]
        idx2 += 1
    curval += (at - event[idx2][0]) * cnt2
    mxval = curval
    mxtime = 0
    while curtime <= pt-at and idx2 < len(event) - 1:
        delta1 = event[idx1+1][0] - curtime
        delta2 = event[idx2+1][0] - (curtime + at)
        if delta1 <= delta2: # 시작 구간이 다음 event에 더 가까운 경우
            curval = curval + (cnt2 - cnt1) * delta1
            cnt1 += event[idx1+1][1]
            idx1 += 1
            curtime += delta1
        else:
            curval = curval + (cnt2 - cnt1) * delta2
            cnt2 += event[idx2+1][1]
            idx2 += 1
            curtime += delta2
        if curval > mxval:
            mxval, mxtime = curval, curtime
    return i2s(mxtime)

 

코드(Java, Prefix Sum)

import java.util.*;
class Solution {
    static int s2i(String s){
        return Integer.parseInt(s.substring(0, 2))*3600 + Integer.parseInt(s.substring(3, 5))*60 + Integer.parseInt(s.substring(6, 8));
    }
    
    static String numzfill(int x){
        if(x < 10) return "0" + String.valueOf(x);
        return String.valueOf(x);
    }
    
    static String i2s(int t){
        String ret = numzfill(t/3600) + ":";
        t %= 3600;
        ret += numzfill(t/60) + ":";
        t %= 60;
        return ret + numzfill(t);
    }    
    
    public String solution(String play_time, String adv_time, String[] logs) {
        int pt = s2i(play_time), at = s2i(adv_time);
        // A. 특정 초에 시청중인 사람의 수 계산
        int d[] = new int[360001];
        for(int i = 0; i < logs.length; i++){
            String l = logs[i];
            int st = s2i(l.substring(0, 8)), en = s2i(l.substring(9, 17));
            d[st]++; d[en]--;
        }
        for(int i = 1; i <= 360000; i++) d[i] += d[i-1];
        // B. 누적 재생시간이 가장 많은 곳 계산
        long mxval = 0, curval = 0;
        int mxtime = 0;
        for(int i = 0; i < at; i++) curval += d[i];
        mxval = curval;
        for(int i = 1; i <= 360000 - at; i++){
            curval = curval - d[i-1] + d[i+at-1];
            if(curval > mxval){
                mxval = curval;
                mxtime = i;
            }
        }
        return i2s(mxtime);
    }
}

코드(Java, 투 포인터)

import java.util.*;
import java.lang.*;

class Solution {
    
    static class pair implements Comparable<pair>{
        int x, y;
        
        public pair(int x, int y){this.x = x; this.y = y;}

        @Override
        public int compareTo(pair o) {
            if (x > o.x) return 1;
            if (x < o.x) return -1;
            if(y > o.y) return 1;
            if(y == o.y) return 0;
            return -1;
        }
    }
    
    static int s2i(String s){
        return Integer.parseInt(s.substring(0, 2))*3600 + Integer.parseInt(s.substring(3, 5))*60 + Integer.parseInt(s.substring(6, 8));
    }
    
    static String numzfill(int x){
        if(x < 10) return "0" + String.valueOf(x);
        return String.valueOf(x);
    }
    
    static String i2s(int t){
        String ret = numzfill(t/3600) + ":";
        t %= 3600;
        ret += numzfill(t/60) + ":";
        t %= 60;
        return ret + numzfill(t);
    }
        
    public String solution(String play_time, String adv_time, String[] logs) {
        int pt = s2i(play_time), at = s2i(adv_time);
        // A. 전처리
        ArrayList<pair> event = new ArrayList<>();
        for(int i = 0; i < logs.length; i++){
            String l = logs[i];
            int st = s2i(l.substring(0, 8)), en = s2i(l.substring(9, 17));
            event.add(new pair(st, 1));
            event.add(new pair(en, -1));            
        }
        event.add(new pair(0, 0));
        Collections.sort(event);
        // B. 누적 재생시간이 가장 많은 곳 계산
        // cnt1 : 시작 구간에서의 시청중인 사람의 수
        // cnt2 : 끝 구간에서의 시청중인 사람의 수
        int idx1 = 0, idx2 = 0, cnt1 = 0, cnt2 = 0;
        long curval = 0, mxval = 0;
        int curtime = 0, mxtime = 0;
        while(idx2 + 1 < event.size() && event.get(idx2+1).x <= at){
            curval += (event.get(idx2+1).x-event.get(idx2).x) * cnt2;
            cnt2 += event.get(idx2+1).y;
            idx2++;
        }
        curval += (at - event.get(idx2).x) * cnt2;
        mxval = curval;

        while(curtime <= pt-at && idx2 + 1 < event.size()){
            int delta1 = event.get(idx1+1).x - curtime;
            int delta2 = event.get(idx2+1).x - (curtime + at);
            if(delta1 <= delta2){ // 시작 구간이 다음 event에 더 가까운 경우
                curval = curval + (long)(cnt2 - cnt1) * delta1;
                cnt1 += event.get(idx1+1).y;
                idx1++;
                curtime += delta1;
            }
            else{
                curval = curval + (long)(cnt2 - cnt1) * delta2;
                cnt2 += event.get(idx2+1).y;
                idx2++;
                curtime += delta2;
            }
            if(curval > mxval){
                mxval = curval;
                mxtime = curtime;
            }
        }
        return i2s(mxtime);
    }
}

 

관련 강의

0x14강 - 투 포인터

  Comments
댓글 쓰기