【问题描述】
在n个城市建立起一套通信系统。 n个城市两两之间有且仅有一条简单路径。
一个通信系统的选择方案是随机选择一段连续序号的点, 方案的代价为从被选择
的点中选择任意一个点, 从该点出发遍历所有被选择点, 并回到出发点的总路径。
现在,你的任务就是求出通信系统代价的期望值。 (对 1000000007取模)
【输入格式】
从文件 communicate.in 中输入数据。
输入的第一行包含一个整数n,表示城市的数量。
第2 至 n 行每行两个整数x,y,表示在 x 和 y 之间连一条边。
【输出格式】
输出到文件 communicate.out 中。
输出一个整数,表示期望值。
【样例输入】
10
2 1
3 2
4 3
5 3
6 5
7 1
8 2
9 3
10 6
【样例输出】
490909103
【数据规模与约定】
对于 20%的数据,n<=100
对于 40%的数据,n<=1000
对于 60%的数据,n<=5000
对于 80%的数据,n<=30000
对于 100%的数据,n<=100000

首先有个结论,如果你选出来了一些点,那么这些点的答案就是这些点构成的虚树的边权和*2

我们考虑每一条边$(x, father_x)$对答案的贡献。

如果把$x$为根的子树在序列上标记为黑色,其他为白色的话。

选出来的区间如果既有黑色又有白色,这条边对这个区间有2的贡献。

考虑容斥,我们只要求出这个序列中有多少个区间内元素颜色相同的区间就好了。

这个可以用线段树来维护,这是一条边的贡献。接下来我们只要合并这些线段树合并就好了。

  1 #include <bits/stdc++.h>
  2 using namespace std;
  3 #define M 100010
  4 #define MOD 1000000007
  5 #define f(n) ((1ll * n * (n + 1) / 2) % MOD)
  6 inline int read() {
  7     char ch = getchar(); int x = 0, f = 1;
  8     while(ch < '0' || ch > '9') {
  9         if(ch == '-') f = -1;
 10         ch = getchar();
 11     }
 12     while('0' <= ch && ch <= '9') {
 13         x = x * 10 + ch - '0';
 14         ch = getchar();
 15     }
 16     return x * f;
 17 }
 18 struct Edge{
 19     int u, v, Next;
 20 } G[M * 2];
 21 int head[M], tot;
 22 int n, Ans;
 23 inline void add(int u, int v) {
 24     G[++ tot] = (Edge){u, v, head[u]};
 25     head[u] = tot;
 26 }
 27 int rt[M], ls[M * 40], rs[M * 40], cnt;
 28 int res[M * 40], L[M * 40], R[M * 40];
 29 inline void maintain(int l, int r, int mid, int o) {
 30     L[o] = R[o] = res[o] = 0;
 31     int lsL = L[ls[o]], lsR = R[ls[o]];
 32     int rsL = L[rs[o]], rsR = R[rs[o]];
 33     if(!ls[o]) {
 34         lsL = mid - l + 1;
 35         lsR = mid - l + 1;
 36     }
 37     if(!rs[o]) {
 38         rsL = r - mid;
 39         rsR = r - mid;
 40     }
 41     if(abs(lsL + rsR) == r - l + 1) {
 42         L[o] = R[o] = lsL + rsR;
 43         return;
 44     }
 45     if(abs(lsL) + abs(rsR) == r - l + 1) {
 46         L[o] = lsL; R[o] = rsR;
 47         return;
 48     }
 49     if(abs(lsL) == mid - l + 1) {
 50         if(1ll * lsL * rsL > 0) {
 51             L[o] = lsL + rsL;
 52             R[o] = rsR;
 53             res[o] = res[rs[o]];
 54         }
 55         else {
 56             L[o] = lsL;
 57             R[o] = rsR;
 58             res[o] = res[rs[o]] + f(abs(rsL));
 59             if(res[o] >= MOD) res[o] -= MOD;
 60         }
 61         return;
 62     }
 63     if(abs(rsR) == r - mid) {
 64         if(1ll * rsR * lsR > 0) {
 65             R[o] = rsR + lsR;
 66             L[o] = lsL;
 67             res[o] = res[ls[o]];
 68         }
 69         else {
 70             R[o] = rsR;
 71             L[o] = lsL;
 72             res[o] = res[ls[o]] + f(abs(lsR));
 73             if(res[o] >= MOD) res[o] -= MOD; 
 74         }
 75         return;
 76     }
 77     L[o] = lsL; R[o] = rsR;
 78     res[o] = res[ls[o]] + res[rs[o]];
 79     if(res[o] >= MOD) res[o] -= MOD;
 80     if(1ll * lsR * rsL > 0) {
 81         res[o] += f(abs(lsR + rsL));
 82         if(res[o] >= MOD) res[o] -= MOD;
 83     }
 84     else {
 85         res[o] += (f(abs(lsR)) + f(abs(rsL))) % MOD;
 86         if(res[o] >= MOD) res[o] -= MOD;
 87     }
 88 }
 89 inline void insert(int &o, int l, int r, int x) {
 90     if(!o) o = ++ cnt;
 91     if(l == r) {
 92         L[o] = R[o] = -1;
 93         return;
 94     }
 95     int mid = (l + r) / 2;
 96     if(x <= mid) insert(ls[o], l, mid, x);
 97     else insert(rs[o], mid + 1, r, x);
 98     maintain(l, r, mid, o);
 99 }
100 inline int merge(int x, int y, int l, int r) {
101     if(!x || !y) return x + y;
102     int ret = ++ cnt;
103     int mid = (l + r) / 2;
104     ls[ret] = merge(ls[x], ls[y], l, mid);
105     rs[ret] = merge(rs[x], rs[y], mid + 1, r);
106     maintain(l, r, mid, ret);
107     return ret;
108 }
109 inline int query(int o) {
110     if(abs(L[o]) == n) return f(n);
111     return (1ll * res[o] + f(abs(L[o])) + f(abs(R[o]))) % MOD;
112 }
113 inline void dfs(int x, int fa) {
114     insert(rt[x], 1, n, x);
115     for(int i = head[x]; i != -1; i = G[i].Next) {
116         if(G[i].v == fa) continue;
117         dfs(G[i].v, x);
118         rt[x] = merge(rt[G[i].v], rt[x], 1, n);
119     }
120     Ans += (f(n) - query(rt[x]) + MOD) % MOD;
121     if(Ans >= MOD) Ans -= MOD;
122 }
123 inline int Power(int x, int y) {
124     int ret = 1;
125     while(y) {
126         if(y & 1) ret = 1ll * ret * x % MOD;
127         x = 1ll * x * x % MOD;
128         y >>= 1;
129     }
130     return ret;
131 }
132 int main() {
133     n = read();
134     memset(head, -1, sizeof(head));
135     for(int i = 1; i < n; ++ i) {
136         int u = read(), v = read();
137         add(u, v); add(v, u);
138     }
139     dfs(1, 0);
140     printf("%d
", 2ll * Ans * Power(f(n), MOD - 2) % MOD);
141 }