ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

21牛客9G - Glass Balls (树上概率dp)

2021-08-15 14:32:32  阅读:211  来源: 互联网

标签:Balls 21 int down Glass num nt po dp


题目

source

题解

对于从\(u\)点出发掉到\(v\)点的球来说,它的贡献是\(dep[u]-dep[v]\)。设对于一个固定的局面,掉到\(v\)点的球的球的个数为\(cnt[v]\),那么所有球的贡献为(即该局面的分数)为:

\[\sum\limits_{i=1}^{n}{dep[i]-\sum\limits_{i=1}^{n}{cnt[i]\cdot dep[i]}} \]

因此,只要分别求出深度总和的期望每个结点掉下去球数的期望即可,可以用树上dp计算。这里有几点要注意的:

  • 局面有合法的情况和非法的情况,因此在转移状态时注意确保的是从合法的子状态以合法的过程转移过来。
  • 树上dp一般计算的是子树的结果,在合并统计答案时要考虑上子树外部分的影响,这也是为什么往往需要两个dfs计算down和up的原因。

从题目中可以容易推得,每个结点的子节点中至多只有一个结点不是“储存点”,否则就是非法的。

设\(dp[i]\)为点\(i\)的子树中到\(i\)的球数的期望;\(down[i]\)为点\(i\)的子树为合法局面的概率;\(up[i]\)为整棵树在点\(i\)​为“储存点”时且除去了\(down[i]\)​的合法概率。这里的\(up[i]\)是为了将\(i\)子树中到点\(i\)的球数的期望转换为整棵树中从\(i\)掉下去的球数的期望,即\(cnt[i]=up[i] \times dp[i]\)。

显然,深度总和的期望就是整棵树合法的概率乘上深度的总和,即\(down[1] \times \sum\limits_{i=1}^n{dep[i]}\)。

\(down\)和\(up\)的转移都比较简单,主要是\(dp\)的转移。设\(P\)为“储存点的概率”,\(t\)为点p子结点的个数。

  • 子结点都是“储存点”,且子节点都合法,此时\(u\)中只有1个球,这种情况的贡献为:

\[dp[u]=1 \times P^t \times \prod_{v {\rm 是}u{\rm的子节点}} {down[v]} \]

  • 子结点\(v\)​不是”储存点“,且子节点都合法,此时\(u\)中除了本身的1个球,还有来自\(dp[v]\)那么多的球,这种情况的贡献为:

\[dp[u]=dp[v]\times P^{t-1}\times (1-P)\times \prod_{v'\neq v}{down[v']}+1\times P^{t-1}\times (1-P)\times\prod_{v' {\rm 是}u{\rm的子节点}} {down[v']} \]

最终答案为:\(down[1] \times \sum\limits_{i=1}^n{dep[i]}-\sum\limits_{i=1}^{n}{up[i] \times dp[i]\cdot dep[i]}\)

#include <bits/stdc++.h>

#define endl '\n'
#define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0)
#define mp make_pair
#define seteps(N) fixed << setprecision(N) 
typedef long long ll;

using namespace std;
/*-----------------------------------------------------------------*/

ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
#define INF 0x3f3f3f3f

const int N = 5e5 + 10;
const int M = 998244353;
const double eps = 1e-5;

ll down[N];
ll up[N];
ll dp[N];
int dep[N];
ll po;
vector<int> np[N];

inline ll qpow(ll a, ll b, ll m) {
    ll res = 1;
    while(b) {
        if(b & 1) res = (res * a) % m;
        a = (a * a) % m;
        b = b >> 1;
    }
    return res;
}

void dfs(int p, int fa, int d) {
    dep[p] = d;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        dfs(nt, p, d + 1);
    }
}

void caldown(int p, int fa) {
    ll lp = 1;
    int num = 0;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        num++;
        caldown(nt, p);
        lp = lp * down[nt] % M;
    }
    if(num)
        lp = lp * (qpow(po, num - 1, M) * (1 - po + M) % M * num % M + qpow(po, num, M)) % M;
    down[p] = lp;
}

void calup(int p, int fa) {
    int num = 0;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        num++;
        up[nt] = down[1] * qpow(down[nt], M - 2, M) % M;
    }
    if(num) {
        ll tp = (qpow(po, num - 1, M) * (1 - po + M) % M * num % M + qpow(po, num, M)) % M;
        for(int nt : np[p]) {
            if(nt == fa) continue;
            up[nt] = up[nt] * qpow(tp, M - 2, M) % M;
            up[nt] = up[nt] * (qpow(po, num - 1, M) * (1 - po + M) % M * (num - 1) % M + qpow(po, num, M)) % M;
            calup(nt, p);
        }
    }
}

void solve(int p, int fa) {
    int num = 0;
    ll lp = 1;
    for(int nt : np[p]) {
        if(nt == fa) continue;
        num++;
        lp = lp * down[nt] % M;
        solve(nt, p);
    }
    dp[p] = qpow(po, num, M) * lp % M;
    if(num)
        for(int nt : np[p]) {
            if(nt == fa) continue;
            // 注意后面1的贡献
            // 不要写成(dp[nt] + 1) * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M
            dp[p] += dp[nt] * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M + qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M;
            // 也可以写成
            // dp[p] += (dp[nt] + down[nt]) * qpow(po, num - 1, M) % M * (1 - po + M) % M * lp % M * qpow(down[nt], M - 2, M) % M;
            
            dp[p] %= M;
        }
}

int main() {
    IOS;
    up[1] = 1;
    int n;
    cin >> n >> po;
    for(int i = 2; i <= n; i++) {
        int f;
        cin >> f;
        np[i].push_back(f);
        np[f].push_back(i);
    }
    dfs(1, 0, 1);
    caldown(1, 0);
    calup(1, 0);
    solve(1, 0);
    ll ans = 0;
    ll tp = down[1];
    for(int i = 1; i <= n; i++) {
        ans = (ans + (tp - up[i] * (dp[i]) % M + M) * dep[i] % M) % M;
    }
    cout << ans << endl;
}

标签:Balls,21,int,down,Glass,num,nt,po,dp
来源: https://www.cnblogs.com/limil/p/15143352.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有