← Home
For the problem statement at 0-999/200-299/240-249/241/problemB.txt, this is a correct solution, but the verifier at 0-999/200-299/240-249/241/verifierB.go ends with "All 200 tests passed". Can you fix the verifier?
package main

import (
	"io"
	"math/bits"
	"os"
	"runtime"
	"strconv"
)

const (
	// B is the number of low-order bits indexed by the tries (bits B-1..0).
	B = 30
	// MOD is the modulus applied to the final answer.
	MOD int64 = 1000000007
)

// arr holds the input values; it is read by countLessAll and countSumLessAll.
// The remaining variables back the counting trie rebuilt by countLessAll.
var (
	arr []int

	c0   []int32  // child for bit value 0 (0 means absent)
	c1   []int32  // child for bit value 1 (0 means absent)
	ccnt []uint16 // number of values stored in each node's subtree
	used []int32  // node ids touched by the last build, zeroed on the next
)

// countLessAll returns how many unordered pairs (i < j) of arr satisfy
// arr[i]^arr[j] < K.
//
// Values are streamed into a binary trie over bits B-1..0 held in the shared
// c0/c1/ccnt slices. Each value is first queried against the values already
// inserted and only then inserted itself, so every pair is counted exactly once.
// Nodes touched by the previous call are listed in used and zeroed on entry,
// which lets repeated calls reuse the same storage.
// NOTE(review): assumes every value is < 1<<B; higher bits are never examined.
func countLessAll(K int) int64 {
	if K <= 0 {
		return 0
	}
	total := len(arr)
	if K >= 1<<B {
		// Every XOR of B-bit values is below 1<<B, so all pairs qualify.
		return int64(total) * int64(total-1) / 2
	}

	zero, one, cnt := c0, c1, ccnt
	for _, id := range used {
		zero[id], one[id], cnt[id] = 0, 0, 0
	}
	used = append(used[:0], 1)

	last := 1 // highest node id handed out so far; node 1 is the root
	var pairs int64

	for _, val := range arr {
		// Query: follow the path whose XOR with val equals K's prefix. Wherever
		// K has a 1 bit, every stored y agreeing with val on that bit yields a
		// smaller XOR, so that whole subtree is counted at once.
		cur := 1
		for bit := B - 1; bit >= 0 && cur != 0; bit-- {
			same, other := zero[cur], one[cur]
			if (val>>bit)&1 == 1 {
				same, other = other, same
			}
			if (K>>bit)&1 == 0 {
				cur = int(same)
				continue
			}
			if same != 0 {
				pairs += int64(cnt[same])
			}
			cur = int(other)
		}

		// Insert: create missing nodes along val's path and bump their counts.
		cur = 1
		cnt[cur]++
		for bit := B - 1; bit >= 0; bit-- {
			slot := &zero[cur]
			if (val>>bit)&1 == 1 {
				slot = &one[cur]
			}
			if *slot == 0 {
				last++
				*slot = int32(last)
				used = append(used, int32(last))
			}
			cur = int(*slot)
			cnt[cur]++
		}
	}

	return pairs
}

// Storage for the summing trie used by countSumLessAll and sumXorSubtree,
// plus the per-value bit positions and powers of two that main fills in.
// NOTE(review): the uint16 counters only fit while len(arr) < 65536.
var (
	s0    []int32  // child for bit value 0 (0 means absent)
	s1    []int32  // child for bit value 1 (0 means absent)
	scnt  []uint16 // number of values stored in each node's subtree
	ssum  []int64  // sum of the values stored in each node's subtree
	sones []uint16 // [node*B+b]: stored values in node's subtree with bit b set
	off   []int    // pos[off[i]:off[i+1]] are the set-bit positions of arr[i]
	pos   []uint8  // concatenated set-bit positions of all values
	pow2  [B]int64 // pow2[i] == 1 << i once main initializes it
)

func sumXorSubtree(node int, x int64, ps []uint8) int64 {
	res := int64(scnt[node])*x + ssum[node]
	base := node * B
	for _, b := range ps {
		bi := int(b)
		res -= 2 * int64(sones[base+bi]) * pow2[bi]
	}
	return res
}

// countSumLessAll returns the number of unordered pairs (i < j) of arr with
// arr[i]^arr[j] < K, together with the sum of those XOR values.
//
// It walks the trie like countLessAll, but on the s0/s1 trie. Those nodes also
// carry the sum of stored values (ssum) and, per bit position, how many stored
// values have that bit set (sones, B entries per node). sumXorSubtree uses
// these to add up x^y over a whole subtree without visiting it.
// For K <= 0 it returns (0, 0); for K >= 1<<B it returns the supplied totals.
//
// NOTE(review): unlike countLessAll there is no reset step, so the trie
// slices must be freshly zeroed before each call (main calls it once).
func countSumLessAll(K int, totalPairs, totalSum int64) (int64, int64) {
	switch {
	case K <= 0:
		return 0, 0
	case K >= 1<<B:
		return totalPairs, totalSum
	}

	last := 1 // highest node id handed out so far; node 1 is the root
	var pairCount, xorSum int64

	for i, val := range arr {
		v64 := int64(val)
		setBits := pos[off[i]:off[i+1]] // positions of the 1 bits of val

		// Query: same walk as countLessAll. For every subtree lying entirely
		// below K, it also sums val^y.
		cur := 1
		for bit := B - 1; bit >= 0 && cur != 0; bit-- {
			same, other := s0[cur], s1[cur]
			if (val>>bit)&1 == 1 {
				same, other = other, same
			}
			if (K>>bit)&1 == 0 {
				cur = int(same)
				continue
			}
			if same != 0 {
				pairCount += int64(scnt[same])
				xorSum += sumXorSubtree(int(same), v64, setBits)
			}
			cur = int(other)
		}

		// absorb adds val to one node's count, sum and per-bit one-counts.
		absorb := func(node int) {
			scnt[node]++
			ssum[node] += v64
			base := node * B
			for _, b := range setBits {
				sones[base+int(b)]++
			}
		}

		// Insert: every node on val's root-to-leaf path absorbs it.
		cur = 1
		absorb(cur)
		for bit := B - 1; bit >= 0; bit-- {
			slot := &s0[cur]
			if (val>>bit)&1 == 1 {
				slot = &s1[cur]
			}
			if *slot == 0 {
				last++
				*slot = int32(last)
			}
			cur = int(*slot)
			absorb(cur)
		}
	}

	return pairCount, xorSum
}

// nextInt parses the next optionally negative decimal integer in data,
// starting at offset *idx.
//
// Leading bytes <= ' ' are skipped, an optional '-' is accepted, and parsing
// stops at the first non-digit. *idx is left just past the consumed bytes.
// If no digits follow, 0 is returned.
func nextInt(data []byte, idx *int) int {
	p := *idx
	for p < len(data) && data[p] <= ' ' {
		p++
	}

	negative := false
	if p < len(data) && data[p] == '-' {
		negative = true
		p++
	}

	v := 0
	for ; p < len(data); p++ {
		d := data[p]
		if d < '0' || d > '9' {
			break
		}
		v = v*10 + int(d-'0')
	}

	*idx = p
	if negative {
		return -v
	}
	return v
}

// main reads "n m a_1 ... a_n" from stdin. It prints, modulo MOD, the sum of
// the m largest XOR values over all unordered pairs of the a_i (the sum over
// all pairs if m exceeds the pair count).
//
// Outline:
//  1. Binary-search the largest T such that at least m pairs have XOR >= T,
//     using countLessAll on the counting trie (c0/c1/ccnt).
//  2. Free that trie, then build the summing trie to get the count and XOR
//     sum of pairs with XOR < T+1.
//  3. Answer = (sum of pair XORs > T) + T * (m - number of pairs with XOR > T).
//
// NOTE(review): assumes 0 <= a_i < 1<<B. bitCnt has only B entries, and the
// tries examine only bits B-1..0.
func main() {
	pow2[0] = 1
	for i := 1; i < B; i++ {
		pow2[i] = pow2[i-1] << 1
	}

	data, err := io.ReadAll(os.Stdin)
	if err != nil {
		// Otherwise a failed read would be parsed as truncated input and
		// silently produce a wrong answer.
		os.Stderr.WriteString("reading input: " + err.Error() + "\n")
		os.Exit(1)
	}
	idx := 0
	n := nextInt(data, &idx)
	m := nextInt(data, &idx)

	// Nothing to choose: no pairs requested or none exist.
	if m == 0 || n <= 1 {
		os.Stdout.Write([]byte("0"))
		return
	}

	// Read the values. For each value, record the positions of its set bits
	// in pos[off[i]:off[i+1]]; bitCnt[b] counts how many values have bit b set.
	arr = make([]int, n)
	off = make([]int, n+1)
	var bitCnt [B]int
	pos = make([]uint8, 0, n*15)

	for i := 0; i < n; i++ {
		x := nextInt(data, &idx)
		arr[i] = x
		off[i] = len(pos)
		u := uint32(x)
		for u != 0 {
			b := bits.TrailingZeros32(u)
			pos = append(pos, uint8(b))
			bitCnt[b]++
			u &= u - 1 // clear the lowest set bit
		}
	}
	off[n] = len(pos)

	totalPairs := int64(n) * int64(n-1) / 2
	m64 := int64(m)

	// XOR sum over all pairs: bit i contributes 2^i for each (one, zero) pair.
	var totalSum int64
	for i := 0; i < B; i++ {
		ones := int64(bitCnt[i])
		zeros := int64(n) - ones
		totalSum += ones * zeros * pow2[i]
	}

	// Root plus at most B new nodes per inserted value, with slack.
	maxNodes := n*(B+1) + 5

	c0 = make([]int32, maxNodes)
	c1 = make([]int32, maxNodes)
	ccnt = make([]uint16, maxNodes)
	used = make([]int32, 0, maxNodes)

	// Invariant: at least m pairs have XOR >= lo; hi = 1<<B is never reached
	// by values that fit in B bits.
	lo, hi := 0, 1<<B
	for lo+1 < hi {
		mid := (lo + hi) >> 1
		cntGE := totalPairs - countLessAll(mid)
		if cntGE >= m64 {
			lo = mid
		} else {
			hi = mid
		}
	}
	T := lo

	// Release the counting trie before allocating the larger summing trie.
	c0 = nil
	c1 = nil
	ccnt = nil
	used = nil
	runtime.GC()

	K := T + 1
	var countLE, sumLE int64

	if K >= 1<<B {
		countLE = totalPairs
		sumLE = totalSum
	} else {
		s0 = make([]int32, maxNodes)
		s1 = make([]int32, maxNodes)
		scnt = make([]uint16, maxNodes)
		ssum = make([]int64, maxNodes)
		sones = make([]uint16, maxNodes*B)

		countLE, sumLE = countSumLessAll(K, totalPairs, totalSum)
	}

	// Take every pair strictly above T, then fill the remainder with pairs
	// whose XOR equals T.
	countGreater := totalPairs - countLE
	sumGreater := totalSum - sumLE
	ans := sumGreater + (m64-countGreater)*int64(T)
	ans %= MOD

	out := strconv.FormatInt(ans, 10)
	os.Stdout.Write([]byte(out))
}