Dancing Link

Tianqi Zhang, July 20, 2018

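Introduction

Dancing Link (more commonly written Dancing Links, or DLX) is a data structure technique proposed by Donald Knuth for implementing Algorithm X; it is named after the way it manipulates linked-list pointers. Algorithm X is a recursive, depth-first, backtracking search algorithm for the exact cover problem. Many familiar problems can be reduced to exact cover, for example the N-queens problem and Sudoku. For Sudoku, however, the problem size is so limited that Algorithm X based on Dancing Link does not do better than a brute-force algorithm.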

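Exact Cover Problem

The exact cover problem is defined as follows: given a binary matrix (every entry is 0 or 1), is there a set of rows such that every column contains exactly one 1 among the selected rows?

For example, consider the following matrix:

Rows {1, 4, 5} of this matrix form such a set.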
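
The definition can be checked mechanically. The sketch below is only an illustration: `isExactCover` and `knuthExample` are hypothetical helper names, and the matrix is the example from Knuth's "Dancing Links" paper, which is consistent with the rows {1, 4, 5} quoted above. Whether it is the exact matrix this article originally displayed is an assumption. Row indices in the code are 0-based, so rows {1, 4, 5} become {0, 3, 4}.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Returns true if the selected rows (0-based indices) together contain
// exactly one 1 in every column of the matrix.
bool isExactCover(const Matrix& matrix, const std::vector<int>& rows) {
    if (matrix.empty()) return false;
    std::vector<int> ones(matrix[0].size(), 0);  // 1s seen per column
    for (int r : rows)
        for (std::size_t j = 0; j < ones.size(); ++j)
            ones[j] += matrix[r][j];
    for (int count : ones)
        if (count != 1) return false;  // column uncovered or covered twice
    return true;
}

// Example matrix from Knuth's "Dancing Links" paper (assumed, not confirmed,
// to be the matrix shown in the original article).
Matrix knuthExample() {
    return {
        {0, 0, 1, 0, 1, 1, 0},
        {1, 0, 0, 1, 0, 0, 1},
        {0, 1, 1, 0, 0, 1, 0},
        {1, 0, 0, 1, 0, 0, 0},
        {0, 1, 0, 0, 0, 0, 1},
        {0, 0, 0, 1, 1, 0, 1},
    };
}
```

Calling `isExactCover(knuthExample(), {0, 3, 4})` returns true, while a selection such as `{0, 1}` fails because the second column is left uncovered.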

Algorithm X

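The steps of Algorithm X are as follows:

  1. If the matrix \(A\) has no columns, the current partial solution is a valid solution and the algorithm terminates successfully; otherwise, go to step 2.
  2. Choose a column \(c\) deterministically (typically the first remaining column). If column \(c\) contains no 1, this branch fails and the algorithm backtracks.
  3. Choose a row \(r\) with \(A_{r,c}=1\) nondeterministically, that is, try each candidate row in turn and backtrack on failure.
  4. Add row \(r\) to the partial solution.
  5. For each column \(j\) with \(A_{r,j}=1\): delete every row \(i\) with \(A_{i,j}=1\) from \(A\), then delete column \(j\) from \(A\).
  6. Repeat the algorithm recursively on the reduced matrix \(A\).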

For a concrete walkthrough of this algorithm, see the wiki or here.
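
The steps above can be sketched directly on a dense 0/1 matrix, without any linked lists. This is a minimal illustration of the search itself, not the implementation given later in this article: the names `search` and `solveExactCover` are hypothetical, and "deleting" rows and columns is modelled with boolean flags that are copied at every level, which is exactly the kind of overhead Dancing Link is meant to avoid.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// rowAlive / colAlive mark the rows and columns still present in the
// reduced matrix. Returns true once every column has been covered.
bool search(const Matrix& a, const std::vector<bool>& rowAlive,
            const std::vector<bool>& colAlive, std::vector<int>& solution) {
    const int rows = static_cast<int>(a.size());
    const int cols = static_cast<int>(colAlive.size());

    // Steps 1-2: take the first remaining column; none left means success.
    int c = -1;
    for (int j = 0; j < cols; ++j)
        if (colAlive[j]) { c = j; break; }
    if (c < 0) return true;

    // Step 3: try every remaining row that has a 1 in column c.
    for (int r = 0; r < rows; ++r) {
        if (!rowAlive[r] || !a[r][c]) continue;
        solution.push_back(r);  // step 4

        // Step 5: drop every column covered by r and every conflicting row.
        std::vector<bool> nextRows = rowAlive, nextCols = colAlive;
        for (int j = 0; j < cols; ++j) {
            if (!colAlive[j] || !a[r][j]) continue;
            for (int i = 0; i < rows; ++i)
                if (a[i][j]) nextRows[i] = false;
            nextCols[j] = false;
        }

        // Step 6: recurse on the reduced matrix, backtrack on failure.
        if (search(a, nextRows, nextCols, solution)) return true;
        solution.pop_back();
    }
    return false;  // column c cannot be covered on this branch
}

// Returns the sorted 0-based row indices of one exact cover, or an empty
// vector if none exists.
std::vector<int> solveExactCover(const Matrix& a) {
    std::vector<int> solution;
    if (a.empty()) return solution;
    search(a, std::vector<bool>(a.size(), true),
           std::vector<bool>(a[0].size(), true), solution);
    std::sort(solution.begin(), solution.end());
    return solution;
}
```

Copying the flag vectors keeps backtracking trivial (nothing has to be undone), at the cost of extra work per level; the linked structure described next restores state by relinking nodes instead.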

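The main idea of Dancing Link comes from doubly linked lists. Knuth observed that a naive implementation of Algorithm X spends most of its time searching the matrix for 1s: selecting a column means scanning the whole matrix, and selecting a row means scanning an entire column. To bring the cost of finding the next 1 down from \(O(n)\) to \(O(1)\), Knuth uses a sparse matrix that stores only the positions of the 1s. At all times, each node is linked to the nodes to its left and right (the 1s in the same row of the original matrix), to the nodes above and below it (the 1s in the same column), and to its column header. Every row and every column forms a circular doubly linked list.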

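Each column has a special node called the column header, which is part of that column's list. The column headers form a special row (the control row) that contains every column still present in the matrix. Each header records the number of nodes in its column, so the column with the fewest nodes can be found in \(O(n)\) time rather than \(O(n \times m)\), where \(n\) is the number of columns and \(m\) is the number of rows. Choosing the column with the fewest nodes improves performance in some cases, but it is not required for every problem.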
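
As a rough sketch of that column choice: if every header carries a node count, one pass over the header ring finds the smallest column. The `ColumnHeader` type, its `size` and `id` fields, and `chooseSmallestColumn` are hypothetical names introduced only for this illustration; the `DCNode` type used in the implementation below does not store a count.

```cpp
#include <cassert>

// Hypothetical column header: only the fields needed for this illustration.
struct ColumnHeader {
    ColumnHeader* left;
    ColumnHeader* right;
    int size;  // number of nodes currently linked into this column
    int id;    // column index in the original matrix
};

// Walks the circular header list once, so the cost is proportional to the
// number of columns still present, not to the size of the matrix.
ColumnHeader* chooseSmallestColumn(ColumnHeader* head) {
    ColumnHeader* best = nullptr;
    for (ColumnHeader* c = head->right; c != head; c = c->right)
        if (!best || c->size < best->size) best = c;
    return best;  // nullptr when no columns remain
}
```

A column with `size == 0` is returned first whenever one exists, which lets the search detect a dead branch immediately.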

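Removing a node is simple: pick a direction and detach the node only on its neighbors' side, by making the two adjacent nodes point past it. Note that the removed node itself still points to its former neighbors. Taking the horizontal direction as an example, unlinking looks like this: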

node->right->left = node->left;
node->left->right = node->right;

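Relinking is simply the inverse operation: because the removed node still remembers its neighbors, it only has to make them point back to itself.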
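
The pair of operations can be tried on a tiny circular list. This is a minimal sketch with a hypothetical `Node` type that has only horizontal links; `unlinkHorizontal` and `relinkHorizontal` are illustrative names, not part of the implementation below.

```cpp
#include <cassert>

// Hypothetical node with horizontal links only.
struct Node {
    Node* left;
    Node* right;
    int value;
};

// Detach n from its row: the neighbors skip over n, but n keeps pointing
// at them, which is what makes the removal reversible.
void unlinkHorizontal(Node* n) {
    n->right->left = n->left;
    n->left->right = n->right;
}

// Undo unlinkHorizontal: n tells its remembered neighbors to point back.
void relinkHorizontal(Node* n) {
    n->left->right = n;
    n->right->left = n;
}
```

For a ring a, b, c, unlinking b leaves a and c adjacent while `b.left` and `b.right` are unchanged, and relinking b restores the original ring. When several nodes are removed, Knuth's formulation restores them in the reverse order of their removal.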

Implementation

// dancing_link.h
#ifndef DANCINGLINK_H // no leading underscore: _UPPERCASE names are reserved in C++
#define DANCINGLINK_H

#include <stdio.h>
#include <map>
#include <unordered_set>
#include <vector>
#include <algorithm>
#include <cassert>

namespace DCLink{

    typedef enum {HORIZONTAL, VERTICAL} Direction;

    struct DCNode{
        DCNode *up, *down, *left, *right, *col_node;
        bool deleted;
        int row, col;

        DCNode(){
            up = this; down = this; left = this; right = this; col_node = this;
            deleted = false; row = 0; col = 0;
        }

        DCNode(int r, DCNode *c){
            up = this; down = this; left = this; right = this; col_node = c;
            deleted = false; row = r; col = c->col;
        }

        void Remove(Direction d){
            deleted = true;
            switch(d){
                case HORIZONTAL: left->right = right; right->left = left; break;
                case VERTICAL: up->down = down; down->up = up; break;
                default: return;
            }
        }

        void Recover(){
            deleted = false;
            left->right = this;
            right->left = this;
            up->down = this;
            down->up = this;
        }
    };

    class DancingLink{
    private:
        DCLink::DCNode *head;
        int row_rank, col_rank;
        bool has_solution;

        void buildDancingLink(const std::vector<std::vector<int>> &matrix);

        void releaseSpace();

        int indexToId(int row, int col){ return row * col_rank + col; }

        void ignoreRow(DCNode *node);

        void chooseRow(DCNode *node);

        void unchooseRow(DCNode *node);

        void reconsiderRow(DCNode *node);

        bool solve(std::vector<int> &solution);

    public:
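        DancingLink(){ head = nullptr; row_rank = 0; col_rank = 0; has_solution = true;}

        // The class owns raw node pointers, so copying would lead to a double free.
        DancingLink(const DancingLink&) = delete;
        DancingLink& operator=(const DancingLink&) = delete;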

        DancingLink(const std::vector<std::vector<int>> &matrix) {
            buildDancingLink(matrix);
        }

        ~DancingLink(){releaseSpace();}

        void Rebuild(const std::vector<std::vector<int>> &matrix){
            releaseSpace();
            buildDancingLink(matrix);
        }

        std::vector<int> Solve();

        void DebugPrint();
    };

} // namespace DCLink

#endif // include guard

// dancing_link.cpp
#include "dancing_link.h"

namespace DCLink{

    void DancingLink::buildDancingLink(const std::vector<std::vector<int>> &matrix){
        assert(!matrix.empty() && !matrix[0].empty());
        row_rank = static_cast<int>(matrix.size());
        col_rank = static_cast<int>(matrix[0].size());
        has_solution = true;

        // check if has solution: a column without any 1 can never be covered.
        // (An all-zero row is harmless: it can simply never be chosen, so it
        // must not be reported as "no solution".)

        for(int j=0;j<col_rank;++j){
            int sum = 0;
            for(int i=0;i<row_rank;++i) sum += matrix[i][j];
            if(!sum){
                has_solution = false;
                head = nullptr;
                return;
            }
        }

        // build head and col-nodes
        head = new DCNode;
        DCNode *curr_tail = head;
        for(int i=0;i<col_rank;++i){
            DCNode *curr_col_node = new DCNode;
            curr_col_node->col = i;

            curr_tail->right = curr_col_node;
            curr_col_node->left = curr_tail;
            head->left = curr_col_node;
            curr_col_node->right = head;

            curr_tail = curr_col_node;
        }

        // build element nodes
        std::map<int, DCNode*> mapper; // node id -> node*
        int col_index = 0;
        DCNode *ptr = head->right;

        // vertical and create node
        while(ptr != head){
            DCNode *walk_down = ptr;
            for(int r=0;r<row_rank;++r){
                if(matrix[r][col_index]){
                    DCNode *new_node = new DCNode(r, ptr);

                    new_node->up = walk_down;
                    walk_down->down = new_node;
                    ptr->up = new_node;
                    new_node->down = ptr;

                    walk_down = new_node;
                    mapper[indexToId(r, col_index)] = new_node;
                }
            }
            ptr = ptr->right;
            ++col_index;
        }

        // horizontal
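        // Named row_head/row_tail so they do not shadow the member `head`.
        DCNode *row_head, *row_tail;
        for(int i=0;i<row_rank;++i){
            row_head = nullptr; row_tail = nullptr;
            for(int j=0;j<col_rank;++j) if(matrix[i][j]){
                if(!row_head) {
                    row_head = mapper[indexToId(i,j)];
                    row_tail = row_head;
                }
                else{
                    // append curr to the circular list of row i
                    DCNode *curr = mapper[indexToId(i,j)];
                    row_tail->right = curr;
                    curr->left = row_tail;
                    row_head->left = curr;
                    curr->right = row_head;

                    row_tail = curr;
                }
            }
        }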
    }

    void DancingLink::releaseSpace(){
        if(!head) return;
        DCNode *ptr = head->right;
        while(ptr!=head){
            DCNode *walker = ptr->down;
            while(walker!=ptr){
                DCNode *tmp = walker;
                walker = walker->down;
                delete tmp;
            }
            DCNode *tmp = ptr;
            ptr = ptr->right;
            delete tmp;
        }
        delete head;
    }

    void DancingLink::ignoreRow(DCNode *node){
        // Ignore a row and remove nodes whose col_node is not removed
        DCNode *horizontal_walker = node->right;
        while(horizontal_walker != node){
            if(!horizontal_walker->col_node->deleted) 
                horizontal_walker->Remove(VERTICAL);
            horizontal_walker = horizontal_walker->right;
        }
    }

    void DancingLink::chooseRow(DCNode *node){
        // Choose a row to the solution, remove corresponding col_node
        DCNode *chosen_row_walker = node;
        do{
            chosen_row_walker->col_node->Remove(HORIZONTAL);
            chosen_row_walker = chosen_row_walker->right;
        }while(chosen_row_walker != node);
    }

    void DancingLink::unchooseRow(DCNode *node){
        // unchoose the row that we have chosen, recover corresponding col_node
        DCNode *horizontal_walker = node;
        do{
            horizontal_walker->col_node->Recover();
            horizontal_walker = horizontal_walker->right;
        }while(horizontal_walker != node);
    }

    void DancingLink::reconsiderRow(DCNode* node){
        DCNode *horizontal_walker = node->right;
        while(horizontal_walker != node){
            if(!horizontal_walker->col_node->deleted)
                horizontal_walker->Recover();
            horizontal_walker = horizontal_walker->right;
        }
    }

    bool DancingLink::solve(std::vector<int> &solution){
        if(head->right == head) return true;
        if(head->right->down == head->right) return false;
        DCNode *curr_col = head->right; // column under consideration
        DCNode *choice = curr_col->down; // row we choose

        while(choice != curr_col){
            // First mark the chosen row and remove the col_node
            std::unordered_set<int> ignore_rows;
            chooseRow(choice);
            solution.push_back(choice->row);

            // Next remove rows that we will not choose.
            DCNode *horizontal_walker = choice;
            do{
                DCNode *vertical_walker = horizontal_walker->down;
                while(vertical_walker != horizontal_walker) {
                    if( vertical_walker->col_node != vertical_walker 
                    && ignore_rows.find(vertical_walker->row)==ignore_rows.end()){
                        ignore_rows.insert(vertical_walker->row);
                        ignoreRow(vertical_walker);
                        //printf("IgnoreRow passed\n");
                    }

                    vertical_walker = vertical_walker->down;
                }

                horizontal_walker = horizontal_walker->right;
            }while(horizontal_walker != choice);

            // recursion
            bool found = solve(solution);

            // Undo this choice even on success, so that the structure is
            // intact afterwards: Solve() can be called again and
            // releaseSpace() can reach every node.
            // First recover rows we ignore.
            if(!found) solution.pop_back();
            do{
                DCNode *vertical_walker = horizontal_walker->down;
                while(vertical_walker != horizontal_walker){
                    if(vertical_walker->col_node != vertical_walker 
                    && ignore_rows.find(vertical_walker->row)!=ignore_rows.end()){
                        ignore_rows.erase(vertical_walker->row);
                        reconsiderRow(vertical_walker);
                    }

                    vertical_walker = vertical_walker->down;
                }

                horizontal_walker = horizontal_walker->right;
            }while(horizontal_walker != choice);

            // Next unmark row that we choose and recover the corresponding col_node
            unchooseRow(choice);

            if(found) return true;

            // next choice
            choice = choice->down;
        }
        return false;
    }

    std::vector<int> DancingLink::Solve(){
        std::vector<int> solution;
        // head is null for a default-constructed object or an unsolvable matrix
        if(!has_solution || !head) return solution;
        solve(solution);
        std::sort(solution.begin(), solution.end());
        return solution;
    }

    void DancingLink::DebugPrint(){
        printf("Head\n");
        if(!head) return; // nothing was built, so there is nothing to print
        DCNode *ptr = head->right;
        for(int i=0;i<col_rank;++i){
            printf("C%d", i);
            DCNode *walker = ptr->down;
            while(walker!=ptr){
                printf(
                    " -> ([%d, %d], left=[%d, %d], right=[%d, %d])", 
                    walker->row, i, 
                    walker->left->row, walker->left->col,
                    walker->right->row, walker->right->col
                );
                walker = walker->down;
            }
            printf("\n");
            ptr = ptr->right;
        }
    }
} // namespace DCLink