LG2014 [CTSC1997]选课

发布于 2020-01-18  2.98k 次阅读


题面

对于这道题,我们考虑在树形dp上套背包。我们会非常自然的采用dfs扫描整棵树,然后对树上的每个节点都进行一次背包。

计$dp[i][j]$为在以第$i$号节点为根结点的子树中,用题目中选法选取$j$项的最大值。

我们在dfs的过程中,采用递归的方式,在子节点都处理完之后,便考虑将所有子节点的答案综合,得到当前节点的答案。

很显然,就是在容量为$j$的01背包中放下$i$节点的所有子节点背包中的答案,我们很容易想到下面的DP方程
$$f[x][j]=max(f[to][k]+f[x][j-k])$$
$x$为当前节点,$to$为它的某个孩子节点,$j$和$k$是枚举的两个变量。

就得到了

for(int i=0;j<G[x].size();i++){
    int to=G[x][i];
    dfs(to,x);
    for(int j=m+1;j>=1;j--){
        for(int k=m;k>=0;k--){
            if(j-k<1) continue;
            f[x][j]=max(f[x][j],f[to][k]+f[x][j-k]);
        }
    }
}

最终的答案就是$f[0][m+1]$了,因为所有没有前提条件的课程都可以指向$0$,即让第$0$号课程成为他们的先决条件,因为$0$号课程自身也算一个课程,所以第二项就是$m+1$了

优化

显然,上述算法时间复杂度为$O(N*M^2)$,虽然在本题已经可以通过,但是如果数据量增加到$n<=2000$时,应该如何应对

由于对于每一个节点上的背包,我们每一次都枚举到了$m+1$,但是大部分情况下实际背包一般不会那么大,所以在时间上会产生很大的开销。

考虑下面的这份代码

int dfs(int node)
{
    int sum=0;
    f[node][1]=cost[node];
    for(int i=0;j<G[x].size();i++){
        int to=G[node][i];
        int cnt=dfs(to);
        sum+=cnt;
        for(int j=m+1;j>=1;j--)
            for(int k=min(cnt,j);k>=0;k--){
                if(j<k+1) continue;
                f[node][j]=max(f[node][j],f[to][k]+f[node][j-k]);
            }
    }
    sum++;
    return sum;
}

用$sum$存放当前节点为根的子树的大小,同时dfs值为子树大小直接进行传值。

显然背包的第二重循环是可以像这样优化的,毕竟子节点产生的序列大小也才那么点,这样的话整个子树上的所有情况都只会在这个节点上体现一次(显然之前的方法大量冗余的计算会使得它制造的无用情况会大大多于现在),所以整棵树一共是$n$个节点,对于每个节点它的孩子的若干个$m$情况只会经过一次,并且在$O(1)$时间内直接取出每组情况最优值,所以对于每个点的复杂度是$O(m)$的,所以总的算法时间复杂度就为$O(N*M)$了,就可以对付增强后的数据了。

代码

#include <bits/stdc++.h>
using namespace std;

vector<vector<int> > graph;
vector<int> cost;
int n,m;

int f[2005][2005];

int dfs(int node)
{
    int sum=0;
    f[node][1]=cost[node];
    for(vector<int>::iterator i=graph[node].begin();i!=graph[node].end();i++){
        int cnt=dfs(*i);
        sum+=cnt;
        for(int j=m+1;j>=1;j--)
            for(int k=sum;k>=0;k--){
                if(j<k+1) continue;
                f[node][j]=max(f[node][j],f[*i][k]+f[node][j-k]);
            }
    }
    sum++;
    return sum;
}

int main()
{
    scanf("%d%d",&n,&m);
    graph.resize(n+1);cost.resize(n+1);
    for(int i=1;i<=n;i++){
        int father;
        scanf("%d%d",&father,&cost[i]);
        graph[father].push_back(i);
    }
    dfs(0);
    printf("%d\n",f[0][m+1]);
    return 0;
}

我缓慢吐出一串啊吧啊吧并不再想说话