// Note: for the problem statement at
// 1000-1999/1900-1999/1960-1969/1967/problemE1.txt this is a correct
// solution, but the verifier at
// 1000-1999/1900-1999/1960-1969/1967/verifierE1.go ends with
// "All tests passed". Can you fix the verifier?

package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
)

// Program-wide numeric parameters.
const (
	// MOD is the modulus every result in this file is reduced by.
	MOD int64 = 998244353
	// MAXN is the largest index covered by the factorial tables built in main.
	MAXN = 200000
	// TH is the largest m for which solve uses its O(n*m) table sweep; larger
	// m switches solve to the binomial-sum formula.
	TH = 400
)

// Factorial lookup tables; both are nil until main fills them.
var (
	// fact[i] is i! reduced modulo MOD.
	fact []int64
	// invFact[i] is the modular inverse of fact[i].
	invFact []int64
)

// FastScanner hands out the unsigned decimal integers found in a byte buffer,
// one per NextInt call, in order of appearance.
type FastScanner struct {
	// data is the complete input.
	data []byte
	// idx is the offset of the first byte that has not been consumed yet.
	idx int
}

// NewFastScanner slurps all of standard input and returns a scanner positioned
// at its first byte.
func NewFastScanner() *FastScanner {
	// The read error is intentionally dropped: whatever bytes did arrive are
	// still scanned, and missing numbers simply come back as 0.
	input, _ := io.ReadAll(os.Stdin)
	fs := new(FastScanner)
	fs.data = input
	return fs
}

// NextInt skips every byte that is not an ASCII digit, then consumes the
// following run of digits and returns its decimal value. Signs are treated
// like any other separator. Once the input is exhausted it returns 0.
func (fs *FastScanner) NextInt() int {
	buf, pos := fs.data, fs.idx
	end := len(buf)

	// Advance to the first digit (or to the end of the input).
	for pos < end {
		if c := buf[pos]; '0' <= c && c <= '9' {
			break
		}
		pos++
	}

	// Accumulate the digit run.
	value := 0
	for pos < end {
		c := buf[pos]
		if c < '0' || c > '9' {
			break
		}
		value = value*10 + int(c-'0')
		pos++
	}

	fs.idx = pos
	return value
}

// modPow returns a raised to the power e, reduced modulo MOD, using binary
// exponentiation (one squaring per bit of e).
//
// The base is first reduced into [0, MOD). Without that step a base whose
// square does not fit in an int64 would silently overflow, and a negative base
// would produce a negative "residue". For bases already in [0, MOD) the result
// is unchanged.
//
// An exponent e <= 0 yields 1: the loop never runs.
func modPow(a, e int64) int64 {
	a %= MOD
	if a < 0 {
		a += MOD
	}

	res := int64(1)
	for ; e > 0; e >>= 1 {
		if e&1 == 1 {
			res = res * a % MOD
		}
		// a < MOD < 2^30, so the square always fits in an int64.
		a = a * a % MOD
	}
	return res
}

// comb looks up the binomial coefficient C(n, k) modulo MOD in the fact and
// invFact tables. It returns 0 whenever k lies outside [0, n]; for k inside
// that range n must be a valid index into fact.
func comb(n, k int) int64 {
	if 0 <= k && k <= n {
		partial := fact[n] * invFact[k] % MOD
		return partial * invFact[n-k] % MOD
	}
	return 0
}

// solve answers one query (n, m, b0) and returns (m^n - invalid) mod MOD.
//
// invalid is a weighted count of walks over the positions 1..m that start at
// s = b0+1. Each step moves one position down (weight 1) or one position up
// (weight m-1). A walk that steps down from position 1 at step k stops there
// and contributes its accumulated weight times m^(n-k). A walk that steps up
// from position m leaves the tracked band and never contributes.
// NOTE(review): presumably invalid is the number of "bad" arrays of
// Codeforces 1967E1; confirm against the problem statement.
//
// The function returns m^n directly when b0 >= m, and also when n < s (at
// least s down-steps are needed before any walk can contribute).
//
// For m <= TH the walks are counted by an O(n*m) table sweep; otherwise by a
// binomial-sum formula that calls comb with first argument up to n-1, so the
// factorial tables must cover that range.
func solve(n, m, b0 int) int64 {
	// powM[i] = m^i mod MOD.
	powM := make([]int64, n+1)
	powM[0] = 1
	base := int64(m)
	for i := 1; i <= n; i++ {
		powM[i] = powM[i-1] * base % MOD
	}

	s := b0 + 1
	if b0 >= m || n < s {
		return powM[n]
	}

	var invalid int64
	if m <= TH {
		invalid = invalidByDP(n, m, s, powM)
	} else {
		invalid = invalidByReflection(n, m, s, powM)
	}
	// Both operands lie in [0, MOD), so adding MOD keeps the value non-negative.
	return (powM[n] - invalid + MOD) % MOD
}

// invalidByDP computes solve's invalid total by sweeping a table indexed by
// position, one step at a time: O(n*m) time, O(m) memory.
func invalidByDP(n, m, s int, powM []int64) int64 {
	cur := make([]int64, m+2)
	next := make([]int64, m+2)
	cur[s] = 1
	up := int64(m - 1)
	var invalid int64

	for step := 1; step <= n; step++ {
		for i := 1; i <= m; i++ {
			next[i] = 0
		}

		// Position 1: stepping down ends the walk; the remaining n-step
		// entries are unconstrained, hence the m^(n-step) factor.
		if v := cur[1]; v != 0 {
			invalid = (invalid + v*powM[n-step]%MOD) % MOD
			if m > 1 {
				next[2] = (next[2] + v*up) % MOD
			}
		}

		// Interior positions move both ways.
		for x := 2; x < m; x++ {
			v := cur[x]
			if v == 0 {
				continue
			}
			next[x-1] = (next[x-1] + v) % MOD
			next[x+1] = (next[x+1] + v*up) % MOD
		}

		// Position m: only the down move stays inside the band. When m == 1
		// this position coincides with position 1, which was handled above.
		if m > 1 {
			if v := cur[m]; v != 0 {
				next[m-1] = (next[m-1] + v) % MOD
			}
		}

		cur, next = next, cur
	}
	return invalid
}

// invalidByReflection computes solve's invalid total from binomial sums. For
// every u with s+2u <= n it takes two sums of comb(T, j), T = s+2u-1, over j
// spaced m+1 apart (one starting at u mod (m+1), subtracted by one starting at
// (u+s) mod (m+1)), and weights the difference by (m-1)^u * m^(n-s-2u).
// NOTE(review): this matches the reflection-principle count of walks from s to
// 1 in T steps that stay strictly between 0 and m+1 — TODO confirm.
func invalidByReflection(n, m, s int, powM []int64) int64 {
	period := m + 1
	upWeight := int64(m - 1)
	upPow := int64(1) // (m-1)^u, updated at the end of each iteration
	var invalid int64

	for u := 0; s+2*u <= n; u++ {
		steps := s + 2*u - 1

		var paths int64
		for j := u % period; j <= steps; j += period {
			paths = (paths + comb(steps, j)) % MOD
		}
		for j := (u + s) % period; j <= steps; j += period {
			paths = (paths - comb(steps, j) + MOD) % MOD
		}

		term := paths * upPow % MOD * powM[n-s-2*u] % MOD
		invalid = (invalid + term) % MOD
		upPow = upPow * upWeight % MOD
	}
	return invalid
}

func main() {
	fact = make([]int64, MAXN+1)
	invFact = make([]int64, MAXN+1)
	fact[0] = 1
	for i := 1; i <= MAXN; i++ {
		fact[i] = fact[i-1] * int64(i) % MOD
	}
	invFact[MAXN] = modPow(fact[MAXN], MOD-2)
	for i := MAXN; i >= 1; i-- {
		invFact[i-1] = invFact[i] * int64(i) % MOD
	}

	fs := NewFastScanner()
	t := fs.NextInt()
	out := bufio.NewWriterSize(os.Stdout, 1<<20)
	for ; t > 0; t-- {
		n := fs.NextInt()
		m := fs.NextInt()
		b0 := fs.NextInt()
		fmt.Fprintln(out, solve(n, m, b0))
	}
	out.Flush()
}