Idea

Suppose k = 1. What would the answer be?

For each edge, the contribution is the number of vertex pairs on its two sides. Suppose the edge connects (u, v), with v as the root of the subtree below it. The contribution of that edge is S(v) * (n - S(v)), where S(v) is the size of v's subtree and n is the number of vertices. The final answer is the sum of the contributions over all edges.
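The edge-contribution sum can be sketched as follows. This is a minimal sketch, not code from the original notes: the function name `sum_of_path_lengths`, the 0-based vertex numbering, and the choice of vertex 0 as the root are assumptions made for illustration. It uses an explicit stack instead of recursion so that deep trees do not hit Python's recursion limit.

```python
def sum_of_path_lengths(n, edges):
    """Sum of distances over all unordered vertex pairs of a tree.

    Assumes vertices are numbered 0..n-1 and `edges` is a list of (u, v) pairs.
    """
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS from vertex 0; every vertex appears after its parent in `order`.
    parent = [-1] * n
    order = []
    stack = [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                stack.append(v)

    # Walk the order backwards so each subtree size is final before it is used.
    size = [1] * n
    total = 0
    for v in reversed(order):
        if parent[v] != -1:
            # The edge (parent[v], v) separates size[v] vertices from the other n - size[v].
            total += size[v] * (n - size[v])
            size[parent[v]] += size[v]
    return total
```

For a path on four vertices, `sum_of_path_lengths(4, [(0, 1), (1, 2), (2, 3)])` adds 1 * 3 + 2 * 2 + 3 * 1 over the three edges.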

Insight: we have a closed form for this answer, and it can also be written as sum(L / 1), where L runs over the lengths of all paths

When k > 1, the final answer should be around sum(L / k). More precisely, it is sum(ceil(L / k))

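Now we can separate the final answer into 2 parts, using ceil(L / k) = (L + f(L, k)) / k

1. sum(L) / k = sum(L / 1) / k = (answer in the k = 1 case) / k

2. sum(f(L, k)) / k, where f(L, k) = k - (L % k), or 0 if L % k = 0, summed over all L

Neither part has to be an integer on its own, so add the two sums first and divide by k once at the end.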
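The identity behind this split can be checked numerically. The helper name `extra` below is hypothetical and simply stands for f(L, k) as defined above. The check confirms that padding L by f(L, k) gives the next multiple of k, so (L + f(L, k)) / k equals ceil(L / k).

```python
def extra(L, k):
    """f(L, k): how far L is below the next multiple of k (0 if already a multiple)."""
    return (k - L % k) % k


# Exhaustive check on a small range: (L + f(L, k)) / k is exactly ceil(L / k).
for k in range(1, 8):
    for L in range(0, 50):
        padded = L + extra(L, k)
        assert padded % k == 0
        assert padded // k == -(-L // k)  # -(-L // k) is integer ceil(L / k)
```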

Now the problem becomes calculating the second part. The key is to swap the order of summation: instead of summing over all L, we sum over all remainders L % k!!!

That is, for each remainder r in [1, k), we count how many paths have a length L s.t. L % k = r, and add (k - r) * number to the second part. Remainder 0 adds nothing.

How to calculate the number of paths?!!!

Insight

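1. We measure any shortest path a -> b through the root of the tree: a -> LCA -> root of the tree -> LCA -> b. That walk has length depth(a) + depth(b) and covers the LCA -> root stretch twice, so the real path length is depth(a) + depth(b) - 2 * depth(LCA)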

2. Therefore, for each node we need to keep track of the number of paths that start at the root of the tree, pass through that node, and end inside its subtree, grouped by length % k

3. During the bottom-up calculation, we maintain only the counts from 2 and use them to update the final answer

4. Since the DFS adds children one by one, for each child we first update the answer and then update the subtree root's state based on the child's state

Implementation

def dfs(root, parent, depth) =

count[root][depth %k] = 1 //just reach the root of this subtree

for each child
  dfs(child, root, depth + 1) //calculate bottom up

  for pathLenPassingRootFromOrigin in [0.. k) //for simplicity, we call the root of the whole tree graph Origin
    for pathLenPassingChildFromOrigin in [0..k)
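        int distMod = ((pathLenPassingRootFromOrigin + pathLenPassingChildFromOrigin - 2 * depth) % k + k) % k; // parenthesized so a negative value is normalized into [0, k)
        int needToAdd = (k - distMod) % k // nothing to add when the distance is already a multiple of k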

/*
We need to multiply here because count[root] covers the root itself plus every child visited so far, and we are now visiting a new child,
so any new path found in this child can be combined with every path through the root discovered so far.
*/
        answer += needToAdd * count[root][pathLenPassingRootFromOrigin] * count[child][pathLenPassingChildFromOrigin]

/*
Done processing this child. We need to update the counts of the paths we have discovered so far. Note that if we have a path from origin
through child with length L, then that same path also passes through root with length L!
*/
  for pathLenPassingRootFromOrigin in [0..k)
        count[root][pathLenPassingRootFromOrigin] += count[child][pathLenPassingRootFromOrigin]
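Putting the pieces together, a runnable version of the approach might look like the sketch below. It is one possible rendering rather than the author's code: the function name `sum_ceil_dist`, the 0-based vertex numbering, rooting at vertex 0, and the iterative traversal (swapped in for the recursive dfs to avoid Python's recursion limit) are all assumptions. It returns (k = 1 answer + second part) / k.

```python
def sum_ceil_dist(n, k, edges):
    """Sum of ceil(dist(a, b) / k) over all unordered vertex pairs of a tree."""
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)

    # Iterative DFS from vertex 0 (the Origin); parents precede children in `order`.
    parent = [-1] * n
    depth = [0] * n
    order = []
    stack = [0]
    while stack:
        u = stack.pop()
        order.append(u)
        for v in adj[u]:
            if v != parent[u]:
                parent[v] = u
                depth[v] = depth[u] + 1
                stack.append(v)

    # count[v][r]: vertices in v's subtree whose depth is r modulo k.
    count = [[0] * k for _ in range(n)]
    for v in range(n):
        count[v][depth[v] % k] = 1

    size = [1] * n
    total_len = 0  # part 1 numerator: the k = 1 answer
    extra = 0      # part 2 numerator: sum of f(L, k)
    for child in reversed(order):  # bottom-up: every subtree is complete when merged
        root = parent[child]
        if root == -1:
            continue
        total_len += size[child] * (n - size[child])
        # Pair vertices already merged under root with vertices in this child's subtree;
        # root is the LCA of every such pair.
        for i in range(k):
            if count[root][i] == 0:
                continue
            for j in range(k):
                dist_mod = (i + j - 2 * depth[root]) % k  # Python's % is already non-negative
                extra += ((k - dist_mod) % k) * count[root][i] * count[child][j]
        for j in range(k):
            count[root][j] += count[child][j]
        size[root] += size[child]

    return (total_len + extra) // k
```

The pairing loop costs O(k^2) per edge, so the whole sketch runs in O(n * k^2) time with O(n * k) memory for the counts.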