Convex Hull Trick - Алгоритмика
Convex Hull Trick

Convex Hull Trick

Эта статья — одна из серии. Рекомендуется сначала прочитать все предыдущие.

Возьмём исходную формулу для $f$ и раскроем скобки в cost:

$$ \begin{aligned} f[i, j] &= \min_{k < i} \{ f[k, j-1] + (x_{i-1}-x_k)^2 \} \\ &= \min_{k < i} \{ f[k, j-1] + x_{i-1}^2 - 2x_{i-1} x_k + x_k^2 \} \end{aligned} $$ Заметим, что $x_{i-1}^2$ не зависит от $k$, значит его можно вынести. Под минимумом тогда останется только $$ \underbrace{(f[k, j-1] + x_k^2)}_{a_k} + \underbrace{(-2 x_k)}_{b_k} \cdot x_{i-1} $$ Выполнив перегруппировку, получаем, что исходное выражение можно переписать как $$ f[i, j] = \min_k \{ (a_k, b_k) \cdot (1, x_{i-1}) \} $$

где под «$\cdot$» имеется в виду скалярное произведение.

Алгоритм

Пусть мы хотим найти оптимальное $k$ для $f[i, j]$. Представим все уже посчитанные релевантные динамики с предыдущего слоя как точки $(a_k, b_k)$ на плоскости.

Чтобы эффективно находить среди них точку с минимальным скалярным произведением, можно поддерживать их нижнюю огибающую — вектор $(1, x_{i-1})$ «смотрит» всегда вверх, поэтому нам интересна только она — и бинпоиском по ней находить оптимальную точку.

Хранить нижнюю огибающую можно просто в стеке. Так как добавляемые точки отсортированы по $x$, её построение будет занимать линейное время, а асимптотика всего алгоритма будет упираться в асимптотику бинарного поиска, то есть будет равна $O(n m \log n)$

struct line {
    int k, b;
    line() {}
    line(int a, int _b) { k = a, b = _b; }
    int get(int x) { return k * x + b; }
};

vector<line> lines; // храним прямые нижней огибающей
vector<int> dots; // храним x-координаты точек нижней огибающей
//     ^ первое правило вещественных чисел
//      считаем, что в dots лежит округленная вниз x-координата

int cross(line a, line b) { // считаем точку пересечения
                            // считаем a.k > b.k
    int x = (b.b - a.b) / (a.k - b.k);
    if (b.b < a.b) x--; // боремся с округлением у отрицательных чисел
    return x;
}


void add(line cur) {
    while (lines.size() && lines.back().get(dots.back()) > cur.get(dots.back())) {
        lines.pop_back();
        dots.pop_back();
    }
    if (lines.empty())
        dots.push_back(-inf);
    else 
        dots.push_back(cross(lines.back(), cur));
    lines.push_back(cur);
}

int get(int x) {
    int pos = lower_bound(dots.begin(), dots.end(), x) - dots.begin() - 1;
    return lines[pos].get(x);
}

В случае нашей конкретной задачи, алгоритм можно и дальше соптимизировать, если вспомнить, что $opt[i, j] \leq opt[i][j+1]$, то есть что оптимальная точка всегда будет «правее». Это позволяет вместо бинпоиска применить метод двух указателей, и таким образом соптимизировать решение до $O(n m)$.