// O(N) algoritm #include #include #include #include #define llong long long using namespace::std; const int MAXN = 5e5 + 5; const llong INF = 1e18; const int ROOT = 0; vector > adj(MAXN, vector(0)); vector participants(MAXN, 0); int N, K; /* Lahendus kasutab modifitseeritud tsentroidi leidmise algoritmi. Harilikus tsentroidi leidmise algoritmis juurestatakse puu ära ja salvestatakse iga tipu alampuu kaal. Tsentroid leitakse liikudes juurest rekursiivelt niikaua kõige raskema lapse suunas, kuni kõige raskema lapse kaal muutub väiksemaks kui K / 2. Selleks, et antud algoritmi kohandada selle ülesande jaoks, määrame kõigepealt juure jaoks osalejaslinnade kauguste ruutude summad. Analoogselt tsentroidi leidmise algortimile on vaja toimumislinna nihutada juurest lasteni niikaua kuni lapsele nihutamine vähendab transpordikulusid. Vaatleme olukorda, kus osalemislinna nihutatakse lapse x suunas. Olgu x alampuu osalejaslinnade kaugused x-ni d1, d2, d3, ... , dm. x-i alampuus olevate linnade arvelt väheneb transpordi kogukulu 2 * (d1 + d2 + ... dm) + m võrra (selles saab veenduda kui d1^2 + ... + dm^2 - (d1 + 1)^2 - ... - (dm + 1)^2 sulud avada). Analoogselt muutub transpordikulud suuremaks osalejaslinnade arvelt, mis asuvad x-i alampuust väljaspool. Seega on lisainfona vaja leida iga tipu alampuu osalejaslinnade arv ja nende kauguste summad alampuu juureni. */ struct Node { int child_cnt; // tipu alampuus olevate osalejaslinnade arv llong child_dist_sum; // tipu alampuus olevate osalejaslinnade kauguste summa int ancestor_cnt; // väljaspool tipu alampuud ülejäänud puu osalejaslinnade arv llong ancestor_dist_sum; // väljaspool tipu alampuud ülejäänud puu osalejaslinnade kauguste summa Node() { child_cnt = 0; child_dist_sum = 0; ancestor_cnt = 0; ancestor_dist_sum = 0; } }; vector nodes(MAXN); void init(int pos, int parent) // sügavuti otsing, mis initsialiseerib child_cnt ja child_dist_sum { int cnt = 0; llong dist_sum = 0; for (auto i : adj[pos]) { if (i != parent) { init(i, pos); cnt += nodes[i].child_cnt; dist_sum += nodes[i].child_dist_sum; } } dist_sum += cnt; if (participants[pos]) { cnt++; } nodes[pos].child_cnt = cnt; nodes[pos].child_dist_sum = dist_sum; return; } void dfs(int pos, int parent) // sügavuti otsing, mis initsialiseerib ancestor_cnt ja ancestor_dist_sum { llong tot_dist_sum = nodes[pos].ancestor_dist_sum + nodes[pos].child_dist_sum; // iga lapse jaoks arvutatakse ancestor_cnt ja ancestor_dist_sum, pos jaoks on vastavad väärtused juba leitud for (auto i : adj[pos]) { if (i != parent) { nodes[i].ancestor_cnt = K - nodes[i].child_cnt; nodes[i].ancestor_dist_sum = (tot_dist_sum - nodes[i].child_dist_sum - nodes[i].child_cnt) + nodes[i].ancestor_cnt; } } for (auto i : adj[pos]) { if (i != parent) { dfs(i, pos); } } } int main() { ifstream fin("transsis.txt"); ofstream fout("transval.txt"); fin >> N >> K; for (int i = 0; i < N - 1; i++) { int u, v; fin >> u >> v; u--; v--; adj[u].push_back(v); adj[v].push_back(u); } for (int i = 0; i < K; i++) { int u; fin >> u; u--; participants[u] = 1; } fin.close(); init(ROOT, -1); // child_cnt ja child_dist_sum määramine dfs(ROOT, -1); // ancestor_cnt ja ancestor_dist_sum määramine int pos = ROOT; int parent = -1; llong answer = 0; vector visited(N, 0); queue > Q; Q.push({pos, 0}); while(!Q.empty()) { // määrame juure jaoks transpordikulud auto cur = Q.front(); Q.pop(); visited[cur.first] = 1; if (participants[cur.first]) { answer += (llong)cur.second * cur.second; } for (auto j : adj[cur.first]) { if (!visited[j]) { Q.push({j, cur.second + 1}); } } } while (1) { // alustades juurest liigume niikaua laste suunas kui see transpordikulusid vähendab int next = -1; // kõige parema transpordikulude muuduga laps llong best_change = 0; // kõige parema transpordikulude muuduga lapse transpordikulude muut for (auto i : adj[pos]) { if (i != parent) { llong change = 2 * nodes[i].ancestor_dist_sum - nodes[i].ancestor_cnt; change -= 2 * nodes[i].child_dist_sum + nodes[i].child_cnt; if (change < best_change) { next = i; best_change = change; } } } answer += best_change; if (best_change == 0) { fout << answer; fout.close(); return 0; } parent = pos; pos = next; } return 0; }