For the problem statement at 1000-1999/1600-1699/1610-1619/1613/problemF.txt, this is a correct solution, but the verifier at 1000-1999/1600-1699/1610-1619/1613/verifierF.go ends with "All tests passed". Can you fix the verifier? package main
import (
"container/heap"
"fmt"
"io"
"os"
)
// Modulus and primitive root used by the number-theoretic transform.
// 998244353 = 119*2^23 + 1 is prime and 3 generates its multiplicative
// group, so power-of-two transform lengths up to 2^23 have roots of unity.
const (
	MOD = 998244353
	G   = 3
)
// power returns base^exp modulo MOD using binary (square-and-multiply)
// exponentiation. base is reduced modulo MOD first; exp is expected to be
// non-negative (a non-positive exp yields 1).
func power(base, exp int64) int64 {
	result := int64(1)
	b := base % MOD
	for e := exp; e > 0; e >>= 1 {
		if e&1 == 1 {
			result = result * b % MOD
		}
		b = b * b % MOD
	}
	return result
}
func modInverse(n int64) int64 {
return power(n, MOD-2)
}
// ntt performs an in-place iterative radix-2 number-theoretic transform of a
// modulo MOD. When invert is true it computes the inverse transform,
// including the final scaling by len(a)^-1.
// len(a) must be a power of two dividing MOD-1 (at most 2^23), and the
// entries are expected to already lie in [0, MOD): the butterflies reduce by
// a single conditional subtraction.
func ntt(a []int64, invert bool) {
	size := len(a)

	// Reorder the elements into bit-reversed index order.
	for i, rev := 1, 0; i < size; i++ {
		mask := size >> 1
		for ; rev&mask != 0; mask >>= 1 {
			rev ^= mask
		}
		rev ^= mask
		if i < rev {
			a[i], a[rev] = a[rev], a[i]
		}
	}

	// Butterfly passes, doubling the block length each time.
	for span := 2; span <= size; span <<= 1 {
		mid := span >> 1
		root := power(G, (MOD-1)/int64(span))
		if invert {
			root = modInverse(root)
		}
		for start := 0; start < size; start += span {
			twiddle := int64(1)
			for k := start; k < start+mid; k++ {
				lo := a[k]
				hi := a[k+mid] * twiddle % MOD
				sum := lo + hi
				if sum >= MOD {
					sum -= MOD
				}
				diff := lo - hi + MOD
				if diff >= MOD {
					diff -= MOD
				}
				a[k], a[k+mid] = sum, diff
				twiddle = twiddle * root % MOD
			}
		}
	}

	if !invert {
		return
	}
	scale := modInverse(int64(size))
	for i := range a {
		a[i] = a[i] * scale % MOD
	}
}
// multiply returns the product of polynomials a and b (coefficients in
// ascending degree order, expected in [0, MOD)), with result coefficients
// reduced modulo MOD. The result has len(a)+len(b)-1 entries, or is empty if
// either input is empty. Small inputs use schoolbook multiplication; larger
// ones go through ntt on zero-padded power-of-two buffers.
func multiply(a, b []int64) []int64 {
	la, lb := len(a), len(b)
	if la == 0 || lb == 0 {
		return []int64{}
	}
	outLen := la + lb - 1

	small := la < 64 || lb < 64 || la*lb <= 4096
	if small {
		// Accumulate in uint64. For reduced inputs each product is below
		// MOD^2 < 1e18, so reducing an entry once it reaches 8e18 keeps the
		// running sum from overflowing.
		acc := make([]uint64, outLen)
		for i, x := range a {
			if x == 0 {
				continue
			}
			ux := uint64(x)
			for j, y := range b {
				k := i + j
				acc[k] += ux * uint64(y)
				if acc[k] >= 8e18 {
					acc[k] %= MOD
				}
			}
		}
		out := make([]int64, outLen)
		for k, v := range acc {
			out[k] = int64(v % MOD)
		}
		return out
	}

	size := 1
	for size < outLen {
		size <<= 1
	}
	fa := make([]int64, size)
	fb := make([]int64, size)
	copy(fa, a)
	copy(fb, b)
	ntt(fa, false)
	ntt(fb, false)
	for i := range fa {
		fa[i] = fa[i] * fb[i] % MOD
	}
	ntt(fa, true)
	return fa[:outLen]
}
// fact[i] holds i! mod MOD and invFact[i] holds the inverse of i! mod MOD,
// for 0 <= i <= the n passed to the most recent initFactorials call.
var fact, invFact []int64

// initFactorials (re)allocates fact and invFact with n+1 entries and fills
// them. Factorials come from a forward running product. Inverse factorials
// come from one modular inversion of n! followed by a backward pass using
// inv((i-1)!) = inv(i!) * i. The inverses are valid only while n < MOD.
func initFactorials(n int) {
	fact = make([]int64, n+1)
	invFact = make([]int64, n+1)

	fact[0] = 1
	for i := 1; i <= n; i++ {
		fact[i] = fact[i-1] * int64(i) % MOD
	}

	invFact[0] = 1
	invFact[n] = modInverse(fact[n])
	for i := n; i > 1; i-- {
		invFact[i-1] = invFact[i] * int64(i) % MOD
	}
}
// nCr returns the binomial coefficient C(n, r) modulo MOD, or 0 when r lies
// outside [0, n]. It reads the fact/invFact tables, so those must already
// cover index n.
func nCr(n, r int) int64 {
	if 0 <= r && r <= n {
		partial := fact[n] * invFact[r] % MOD
		return partial * invFact[n-r] % MOD
	}
	return 0
}
// buildPoly returns the coefficients of (1 + d*x)^C modulo MOD in ascending
// degree order: entry k is C(C, k) * d^k. The result has C+1 entries. It uses
// nCr, so the factorial tables must cover index C.
func buildPoly(d int64, C int) []int64 {
	coeffs := make([]int64, 0, C+1)
	step := int64(1)
	for k := 0; k <= C; k++ {
		coeffs = append(coeffs, nCr(C, k)*step%MOD)
		step = step * d % MOD
	}
	return coeffs
}
// Poly is a polynomial stored as coefficients in ascending degree order.
type Poly []int64

// PolyHeap implements heap.Interface as a min-heap of polynomials ordered by
// length (number of coefficients), so the shortest polynomial is on top.
type PolyHeap []Poly

func (h PolyHeap) Len() int           { return len(h) }
func (h PolyHeap) Less(i, j int) bool { return len(h[i]) < len(h[j]) }
func (h PolyHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

// Push appends x, which must be a Poly. It is meant to be called through
// heap.Push, which then restores the heap order.
func (h *PolyHeap) Push(x interface{}) {
	*h = append(*h, x.(Poly))
}

// Pop removes and returns the last element. heap.Pop calls it after moving
// the minimum to the end. The vacated slot is cleared so the backing array
// does not keep the popped polynomial reachable after the heap shrinks.
func (h *PolyHeap) Pop() interface{} {
	old := *h
	last := len(old) - 1
	x := old[last]
	old[last] = nil
	*h = old[:last]
	return x
}
// main reads n and the n-1 edges of a tree on vertices 1..n from stdin,
// roots the tree at vertex 1, and prints
//
//	sum_k (-1)^k * (n-k)! * [x^k] prod_v (1 + c_v*x)   (mod MOD)
//
// where c_v is the number of children of v. Vertices that share a child
// count d are grouped, so each group contributes one factor (1 + d*x)^count.
// The factors are then multiplied shortest-first via a PolyHeap.
func main() {
	input, err := io.ReadAll(os.Stdin)
	if err != nil {
		fmt.Fprintln(os.Stderr, "reading input:", err)
		os.Exit(1)
	}
	idx := 0
	// nextInt skips non-digit bytes and parses the next unsigned integer; it
	// returns 0 once the input is exhausted.
	nextInt := func() int {
		for idx < len(input) && (input[idx] < '0' || input[idx] > '9') {
			idx++
		}
		if idx >= len(input) {
			return 0
		}
		res := 0
		for idx < len(input) && input[idx] >= '0' && input[idx] <= '9' {
			res = res*10 + int(input[idx]-'0')
			idx++
		}
		return res
	}

	n := nextInt()
	if n == 0 {
		return
	}
	adj := make([][]int, n+1)
	for i := 0; i < n-1; i++ {
		u, v := nextInt(), nextInt()
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}
	initFactorials(n + 5)

	deg := childCounts(adj, 1)

	// counts[d] is the number of vertices with exactly d > 0 children. A
	// slice (rather than a map) makes the heap construction order
	// deterministic.
	counts := make([]int, n+1)
	for v := 1; v <= n; v++ {
		if deg[v] > 0 {
			counts[deg[v]]++
		}
	}

	h := &PolyHeap{}
	for d, c := range counts {
		if c > 0 {
			heap.Push(h, Poly(buildPoly(int64(d), c)))
		}
	}
	if h.Len() == 0 {
		heap.Push(h, Poly{1})
	}
	// Always multiply the two shortest polynomials to keep products small.
	for h.Len() > 1 {
		p1 := heap.Pop(h).(Poly)
		p2 := heap.Pop(h).(Poly)
		heap.Push(h, Poly(multiply(p1, p2)))
	}
	ansPoly := heap.Pop(h).(Poly)

	ans := int64(0)
	for k, coef := range ansPoly {
		term := coef * fact[n-k] % MOD
		if k%2 == 1 {
			ans = (ans - term + MOD) % MOD
		} else {
			ans = (ans + term) % MOD
		}
	}
	fmt.Println(ans)
}

// childCounts runs a BFS from root over the adjacency lists adj and returns,
// for every vertex index, how many neighbours it discovered first (its number
// of children in the BFS tree). Vertices not reachable from root get 0.
func childCounts(adj [][]int, root int) []int {
	deg := make([]int, len(adj))
	visited := make([]bool, len(adj))
	visited[root] = true
	queue := make([]int, 0, len(adj))
	queue = append(queue, root)
	for head := 0; head < len(queue); head++ {
		u := queue[head]
		for _, v := range adj[u] {
			if !visited[v] {
				visited[v] = true
				deg[u]++
				queue = append(queue, v)
			}
		}
	}
	return deg
}