← Home
For the problem statement at 1000-1999/1900-1999/1910-1919/1917/problemD.txt, this is a correct solution, but the verifier at 1000-1999/1900-1999/1910-1919/1917/verifierD.go ends with "All 100 tests passed". Can you fix the verifier?
package main

import (
	"io"
	"math/bits"
	"os"
	"strconv"
)

const (
	// MOD is the modulus every printed answer is reduced by.
	MOD int64 = 998244353
	// LIM is the threshold (2^60) at which the running int64 accumulator in
	// main is reduced mod MOD, so the reduction is not paid on every addition.
	LIM int64 = 1 << 60
)

// FastScanner pulls integers out of an input that has already been read
// completely into memory.
type FastScanner struct {
	data []byte // the entire input
	idx  int    // read cursor: index of the next unread byte in data
}

// nextInt skips any bytes <= ' ' (spaces, tabs, newlines, control bytes),
// then parses an optional '-' followed by a run of ASCII digits, stopping at
// the first non-digit byte. At end of input it returns 0. If the cursor sits
// on a byte that is neither whitespace, '-', nor a digit, it returns 0
// without advancing.
func (fs *FastScanner) nextInt() int {
	buf, pos := fs.data, fs.idx

	for pos < len(buf) && buf[pos] <= ' ' {
		pos++
	}
	if pos >= len(buf) {
		fs.idx = pos
		return 0
	}

	negative := buf[pos] == '-'
	if negative {
		pos++
	}

	result := 0
	for ; pos < len(buf) && '0' <= buf[pos] && buf[pos] <= '9'; pos++ {
		result = result*10 + int(buf[pos]-'0')
	}
	fs.idx = pos

	if negative {
		return -result
	}
	return result
}

// add increments by one the Fenwick (binary indexed) tree held in bit at
// 1-based position idx. It updates idx and then each position reached by
// repeatedly adding the lowest set bit, until the position reaches len(bit).
// idx must be >= 1: with idx == 0 the position never advances.
func add(bit []int, idx int) {
	for pos := idx; pos < len(bit); pos += pos & -pos {
		bit[pos]++
	}
}

// sum returns the Fenwick-tree prefix total for positions 1..idx. It adds
// bit[idx] and then the entry at each position obtained by clearing the
// lowest set bit, until the position reaches 0. Any idx <= 0 yields 0.
func sum(bit []int, idx int) int {
	total := 0
	for pos := idx; pos > 0; pos &= pos - 1 {
		total += bit[pos]
	}
	return total
}

// rangeCount returns the Fenwick-tree total over the indices that hold odd
// values v with lVal <= v <= rVal. The query is clipped to [1, maxVal] and to
// tree indices 1..n.
//
// The index arithmetic assumes odd value v is stored at index (v+1)/2, so
// 1 -> 1, 3 -> 2, ..., 2n-1 -> n. An empty or out-of-range query returns 0.
func rangeCount(bit []int, n, maxVal, lVal, rVal int) int {
	lo := lVal
	if lo < 1 {
		lo = 1
	}
	hi := rVal
	if hi > maxVal {
		hi = maxVal
	}
	if lo > hi {
		return 0
	}

	// first: index of the smallest odd value >= lo.
	// last:  index of the largest odd value <= hi.
	first := (lo + 2) >> 1
	last := (hi + 1) >> 1
	if last > n {
		last = n
	}
	if first > last {
		return 0
	}
	return sum(bit, last) - sum(bit, first-1)
}

// countPairsGE returns, for k >= 1, how many ordered pairs (a, b) with
// 0 <= a, b < k satisfy a - b >= t.
//
// For t >= 0 the admissible differences are t..k-1, which contribute
// k-t, k-t-1, ..., 1 pairs: a triangular number. For t < 0 it is simpler to
// subtract the complement from k*k, namely the pairs with b - a >= 1 - t.
func countPairsGE(k int, t int) int64 {
	side := int64(k)
	switch {
	case t <= -(k - 1):
		// Every difference lies in [-(k-1), k-1], so all k*k pairs qualify.
		return side * side
	case t > k-1:
		// No difference can reach t.
		return 0
	case t >= 0:
		width := int64(k - t)
		return width * (width + 1) / 2
	default:
		shift := int64(-t)
		missing := (side - shift - 1) * (side - shift) / 2
		return side*side - missing
	}
}

// main solves every test case found on stdin and prints one answer per line.
//
// Input as consumed here: a test-case count, then for each test case n and
// k, followed by n values y (tracked in bitP) and k values (shifted by +1 and
// tracked in bitQ). All reads go through FastScanner, which yields 0 once the
// input is exhausted.
//
// NOTE(review): the math appears to count inversions of the n*k-long array
// whose entry (i, j) is y_i * 2^{q_j}, reduced mod MOD. Confirm this against
// the problem statement. It relies on the y values being distinct odd numbers
// in [1, 2n-1] (bitP is sized n+1 and indexed by (y+1)>>1) and on the q values
// being a permutation of [0, k) (bitQ is sized k+1).
func main() {
	// A read error is deliberately ignored: whatever bytes were read are still
	// parsed, and nextInt returns 0 past the end.
	data, _ := io.ReadAll(os.Stdin)
	fs := FastScanner{data: data}

	t := fs.nextInt()
	// Each answer is < MOD (at most 9 digits) plus '\n', so 12 bytes per test
	// case is enough to avoid regrowing the buffer.
	out := make([]byte, 0, t*12)

	for ; t > 0; t-- {
		n := fs.nextInt()
		k := fs.nextInt()

		// maxVal is the largest y value expected. maxShift is its bit length,
		// which bounds how many halvings/doublings of y can stay in [1, maxVal].
		maxVal := 2*n - 1
		maxShift := bits.Len(uint(maxVal))

		// Precomputed weights, reduced mod MOD:
		//   posMod[d] = countPairsGE(k, d)   for d in 1..maxShift
		//   negMod[s] = countPairsGE(k, -s)  for s in 0..maxShift
		//   tailMod   = k*k, the value countPairsGE takes once t <= -(k-1)
		posMod := make([]int64, maxShift+1)
		negMod := make([]int64, maxShift+1)
		for i := 1; i <= maxShift; i++ {
			posMod[i] = countPairsGE(k, i) % MOD
		}
		for i := 0; i <= maxShift; i++ {
			negMod[i] = countPairsGE(k, -i) % MOD
		}
		tailMod := (int64(k) % MOD) * (int64(k) % MOD) % MOD

		bitP := make([]int, n+1)
		var ans int64

		// countPairsGE(k, d) is 0 for d > k-1, so larger d never contributes.
		// Capping at maxShift also keeps the posMod index in range.
		limitPos := k - 1
		if limitPos > maxShift {
			limitPos = maxShift
		}
		// negMod is used only for s <= k-2. For s >= k-1 the weight is k*k,
		// which the tail step below applies to all remaining larger values at once.
		limitNeg := k - 2
		if limitNeg > maxShift {
			limitNeg = maxShift
		}

		for i := 0; i < n; i++ {
			y := fs.nextInt()

			// Earlier values v <= y: bucket d is (y>>d, y>>(d-1)]. The buckets
			// for d = 1, 2, ... partition (0, y], and bucket d has weight
			// countPairsGE(k, d).
			for d := 1; d <= limitPos; d++ {
				upper := y >> (d - 1)
				if upper < 1 {
					// y has been shifted to 0; no further bucket is non-empty.
					break
				}
				lower := (y >> d) + 1
				if lower <= upper {
					cnt := rangeCount(bitP, n, maxVal, lower, upper)
					if cnt != 0 {
						ans += int64(cnt) * posMod[d]
						// Lazy reduction: only fold mod MOD once the accumulator
						// reaches 2^60. This assumes each added term stays well
						// below 2^63 - 2^60.
						if ans >= LIM {
							ans %= MOD
						}
					}
				}
			}

			// Earlier values v > y: bucket s is (y*2^s, y*2^(s+1)], so the first
			// bucket is [y+1, 2y]. Bucket s has weight countPairsGE(k, -s).
			// int64 is used so the doubling cannot overflow before the bound check.
			curL := int64(y) + 1
			curR := int64(y) << 1
			for s := 0; s <= limitNeg && curL <= int64(maxVal); s++ {
				upper := curR
				if upper > int64(maxVal) {
					upper = int64(maxVal)
				}
				cnt := rangeCount(bitP, n, maxVal, int(curL), int(upper))
				if cnt != 0 {
					ans += int64(cnt) * negMod[s]
					if ans >= LIM {
						ans %= MOD
					}
				}
				curL = curR + 1
				curR <<= 1
			}

			// Every earlier value above the last processed bucket gets the full
			// k*k weight.
			if curL <= int64(maxVal) {
				cnt := rangeCount(bitP, n, maxVal, int(curL), maxVal)
				if cnt != 0 {
					ans += int64(cnt) * tailMod
					if ans >= LIM {
						ans %= MOD
					}
				}
			}

			// Record y (odd value -> index (y+1)/2) so later iterations see it.
			add(bitP, (y+1)>>1)
		}

		ans %= MOD

		// Standard BIT inversion count over the k values (shifted to 1-based):
		// le earlier values are <= x, so i-le earlier values are greater.
		bitQ := make([]int, k+1)
		var invQ int64
		for i := 0; i < k; i++ {
			x := fs.nextInt() + 1
			le := sum(bitQ, x)
			invQ += int64(i - le)
			add(bitQ, x)
		}

		// Add n copies of the q-inversion count (presumably one per block of
		// the n*k array).
		ans += (int64(n) % MOD) * (invQ % MOD)
		ans %= MOD

		out = strconv.AppendInt(out, ans, 10)
		out = append(out, '\n')
	}

	// The write error is ignored: there is nothing useful to do if stdout fails.
	_, _ = os.Stdout.Write(out)
}