POJ2054-Color a Tree
这是我近期写过最难的贪心题之一。
题目大意:给出一棵有根树,要把这棵树的每个节点都涂色。除了根节点外,只有当一个结点的所有祖先都涂上颜色了之后才能给这个结点涂色。给每个点i涂色都有一个代价系数c[i]
,在时刻t给i结点涂色的代价是t * c[i]
,问怎样涂色能使得将整棵树涂色所用代价最小。
首先考虑动态规划。但虽然这题的n看着不大,但将一个节点涂色的先决条件是祖先都全部涂上色,动态规划的状态也必然需要将一个点各祖先的涂色状态都记录下来,这样就肯定就爆复杂度了,当然不可行。
贪心,第一个想到的贪心就是每次找到可选的点中代价系数c最大的(我用的是set),将它涂上色,直到所有点都涂完颜色。虽然这么写可以过样例,但却是错误的。看一下下面这个案例就明白了
1 2 3 4 5
| 4 1 1 100 2 1000 1 2 1 3 3 4
|
正确的输出是3405,涂色顺序是1->3->4->2。
这时我又想,是否可以每次选出可以选的所有点中,子树上代价系数c之和最大的点来涂色,但很快又想到了问题:如果两个子树上代价系数之和相等怎么办。我又想到了优先选子树上代价系数c的平均数最大的点来涂色,稀里糊涂地写了出来结果连样例都过不了,灰溜溜地看题解去了。
题解指出:在操作中,(除根节点外)代价系数c最大的节点一定紧接着它的父节点涂色,因此可以将c最大的结点与它的父节点看成是一个节点来进行操作,即将它们合并成一个新节点,新节点的父节点是原父节点的父节点,子节点包括原来的这个最大节点和它的父节点子节点。
知道这一点有什么用?如果能等效出合并成的新节点的“代价系数”,作为点的权值,就可以对于整棵树不断找出权值最大的结点,合并,直至整棵树只剩根节点。最后就将根节点按照合并的顺序输出就好了。
那么这个等效的权值应该怎么算?
首先明确一点:之所以每次都找权值最大的点跟父节点合并,是因为一旦它的父亲涂色了,就需要立刻将它与涂色,如果先处理其它点,最后产生的总代价一定比立刻给这个点涂色要大。
以这样的一棵树为例
1 2 3 4 5
| 4 1 1 100 2 1000 1 2 1 3 3 4
|
第一步显然是将4号结点与它的父亲3号结点合并成一个新的结点。然后就面临了下面两种选择:
- 将2号结点与它的父节点1号结点合并
- 将刚刚形成的新节点与他的父结点1号结点合并
选择1的本质是:在根节点涂色之后马上将2号结点涂色,之后再给3号4号结点涂色;选择2的本质是在根节点涂色之后马上将3号结点涂色,然后再马上将4号结点涂色,之后再给2号结点涂色。
由于根节点一定要在第一个涂色,两种选择的最终总代价分别是
1 2 3
| 1*c[1] + 2*c[2] + 3*c[3] + 4*c[4]
1*c[1] + 2*c[3] + 3*c[4] + 4*c[2]
|
做差,得到
当这个式子的值大于0,就说明选择1的最终代价比选择2大,反之亦然。因此需要比较c[3]+c[4]
和2*c[2]
的大小关系,也就是c[3]、c[4]
的平均数和c[2]
的关系。
因此,将点合并后,新点的权值就可以等效为它所包含所有点的代价系数c的平均数,并依照这个权值决定后续的合并操作,(这一步离上面的推导在逻辑较远,因为博主无法找到合适的言语来描述其间的逻辑关系。如果没看懂的话建议鲨害博主)。
知道这一点就很简单了:每个点的初始权值都为这个点的代价系数c,然后每次都在树上找到除根节点外权值最大的点,将它与它的父节点合并成一个新的结点,直到整棵树只剩下根节点。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
| #include <bits/stdc++.h> using namespace std; int c[1005]; int pre[1005]; int n, rt; vector<int> son[1005]; struct nd { int root; double val; double csum; queue<int> q; bool operator<(const nd &x) const { if (val == x.val) return root > x.root; return val > x.val; } } node[1005]; int main() { int i, j, u, v, root, res; set<nd> st; while (scanf("%d%d", &n, &rt)) { if (n == 0 && rt == 0) return 0; st.clear(); for (i = 1; i <= n; i++) { scanf("%d", &c[i]); son[i].clear(); node[i].root = i; node[i].val = node[i].csum = c[i]; node[i].q.push(i); } for (i = 1; i < n; i++) { scanf("%d%d", &u, &v); pre[v] = u; son[u].push_back(v); } for (i = 1; i <= n; i++) { if (i != rt) st.insert(node[i]); } while (!st.empty()) { root = st.begin()->root; st.erase(node[root]); if (pre[root] != rt) st.erase(node[pre[root]]); while (!node[root].q.empty()) { node[pre[root]].q.push(node[root].q.front()); node[root].q.pop(); } for (i = 0; i < son[root].size(); i++) { pre[son[root][i]] = pre[root]; son[pre[root]].push_back(son[root][i]); } node[pre[root]].csum += node[root].csum; node[pre[root]].val = node[pre[root]].csum / node[pre[root]].q.size(); if (pre[root] != rt) st.insert(node[pre[root]]); } res = 0; for (i = 1; !node[rt].q.empty(); i++) { res += c[node[rt].q.front()] * i; node[rt].q.pop(); } printf("%d\n", res); } }
|