← Home
For the problem statement at 0-999/800-899/830-839/833/problemD.txt, this is a correct solution, but the verifier at 0-999/800-899/830-839/833/verifierD.go ends with "All tests passed". Can you fix the verifier? package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
	"sort"
)

// MOD is the modulus (10^9+7) applied to every weight product. It is prime,
// which modInv relies on (Fermat's little theorem: a^(MOD-2) = a^-1 mod MOD).
const MOD int64 = 1000000007

// Edge is one endpoint's view of an undirected tree edge; main appends an
// Edge to the adjacency lists of both endpoints.
type Edge struct {
	// to is the neighbouring vertex.
	to     int
	// w is the edge weight x read in main.
	w      int64
	// dp, dq are added to a path's (p, q) coordinates when it crosses this
	// edge; main sets (1, -2) when the edge's flag c == 1 and (-2, 1) otherwise.
	dp, dq int
}

// Point summarises the path from the current centroid to one vertex, as
// produced by collectPoints.
type Point struct {
	// p, q are the sums of the dp/dq steps along the path.
	p, q int
	// w is the product of the path's edge weights, reduced mod MOD after each
	// multiplication (the first factor is the raw edge weight).
	w    int64
}

// State is one frame of the explicit DFS stack used by collectPoints.
type State struct {
	// v is the vertex to visit; par is the vertex it was reached from, which
	// must not be walked back into.
	v, par int
	// p, q, w are the Point values of the path ending at v.
	p, q   int
	w      int64
}

// data holds the whole of stdin (read once in main); inputPtr is nextInt's
// read cursor into it.
var data []byte
var inputPtr int

// nextInt scans data from inputPtr for the next (optionally negative) decimal
// integer, advances inputPtr past it and returns its value. Any byte that is
// neither a digit nor '-' is treated as a separator.
//
// If no further digit or sign is found, it returns 0 and leaves inputPtr at
// len(data) instead of indexing past the end of data.
func nextInt() int {
	// Skip separators up to the first sign or digit.
	for inputPtr < len(data) {
		c := data[inputPtr]
		if c == '-' || (c >= '0' && c <= '9') {
			break
		}
		inputPtr++
	}
	// Input exhausted: there is no byte at inputPtr to inspect.
	if inputPtr >= len(data) {
		return 0
	}
	sign := 1
	if data[inputPtr] == '-' {
		sign = -1
		inputPtr++
	}
	val := 0
	for inputPtr < len(data) {
		c := data[inputPtr]
		if c < '0' || c > '9' {
			break
		}
		val = val*10 + int(c-'0')
		inputPtr++
	}
	return sign * val
}

// n is the vertex count; vertices are numbered 1..n (main sizes g as n+1 and
// starts decomposition at vertex 1).
var n int
// g is the adjacency list; every undirected edge is stored at both endpoints.
var g [][]Edge
// removed marks vertices already chosen as centroids by decompose.
var removed []bool
// parentArr and sizeArr are per-vertex scratch written by findCentroid.
var parentArr []int
var sizeArr []int

// orderBuf and stackBuf are reusable DFS buffers for findCentroid.
var orderBuf []int
var stackBuf []int

// collectStack and collectPointsBuf are reusable buffers for collectPoints.
// collectPointsBuf is returned to the caller directly, so its contents are
// overwritten by the next collectPoints call.
var collectStack []State
var collectPointsBuf []Point

// tmpNeg and tmpPos are calcF's partition of points with p <= 0 and p > 0.
var tmpNeg []Point
var tmpPos []Point

// qOff shifts q values so they can be used as Fenwick-tree indices.
var qOff int
// bitSize is the length of the Fenwick arrays (valid indices 1..bitSize-1).
var bitSize int
// bitCnt and bitProd hold, per Fenwick cell, a count of inserted points and
// the product (mod MOD) of their weights.
var bitCnt []int64
var bitProd []int64
// bitTag/curTag implement lazy clearing: a cell's bitCnt/bitProd are valid
// only when bitTag equals curTag, so bitStart can reset the tree in O(1).
var bitTag []int
var curTag int

// ans accumulates the final product (mod MOD); main prints it.
var ans int64 = 1

// powmod returns a^e mod MOD by binary exponentiation (square-and-multiply).
// A non-positive exponent yields 1.
func powmod(a, e int64) int64 {
	acc := int64(1)
	for base, exp := a, e; exp > 0; exp >>= 1 {
		// Fold in the current power of a whenever the low bit of exp is set.
		if exp%2 == 1 {
			acc = acc * base % MOD
		}
		base = base * base % MOD
	}
	return acc
}

func modInv(a int64) int64 {
	return powmod(a, MOD-2)
}

// bitStart logically empties the Fenwick tree in O(1): bumping curTag makes
// every cell whose bitTag holds an older tag count as empty.
func bitStart() {
	curTag += 1
}

// bitUpdate inserts one point with weight w at Fenwick index idx. It bumps
// the count and multiplies the product (mod MOD) in every cell covering idx.
// Cells carrying a stale tag are reset to (count 0, product 1) first.
// idx is expected to be positive.
func bitUpdate(idx int, w int64) {
	for i := idx; i < bitSize; i += i & -i {
		if bitTag[i] != curTag {
			bitTag[i], bitCnt[i], bitProd[i] = curTag, 0, 1
		}
		bitCnt[i]++
		bitProd[i] = bitProd[i] * w % MOD
	}
}

// bitQuery aggregates the prefix [1, idx] of the Fenwick tree. It returns the
// number of points inserted at indices <= idx since the last bitStart, and
// the product (mod MOD) of their weights. Cells with a stale tag are skipped.
func bitQuery(idx int) (int64, int64) {
	var total int64
	product := int64(1)
	// i &= i-1 clears the lowest set bit, the same step as i -= i & -i.
	for i := idx; i > 0; i &= i - 1 {
		if bitTag[i] != curTag {
			continue
		}
		total += bitCnt[i]
		product = product * bitProd[i] % MOD
	}
	return total, product
}

func findCentroid(start int) (int, int) {
	orderBuf = orderBuf[:0]
	stackBuf = stackBuf[:0]
	parentArr[start] = 0
	stackBuf = append(stackBuf, start)

	for len(stackBuf) > 0 {
		v := stackBuf[len(stackBuf)-1]
		stackBuf = stackBuf[:len(stackBuf)-1]
		orderBuf = append(orderBuf, v)
		for _, e := range g[v] {
			to := e.to
			if removed[to] || to == parentArr[v] {
				continue
			}
			parentArr[to] = v
			stackBuf = append(stackBuf, to)
		}
	}

	compSize := len(orderBuf)

	for i := compSize - 1; i >= 0; i-- {
		v := orderBuf[i]
		sizeArr[v] = 1
		for _, e := range g[v] {
			to := e.to
			if removed[to] || parentArr[to] != v {
				continue
			}
			sizeArr[v] += sizeArr[to]
		}
	}

	centroid := start
	best := compSize + 1
	for _, v := range orderBuf {
		mx := compSize - sizeArr[v]
		for _, e := range g[v] {
			to := e.to
			if removed[to] || parentArr[to] != v {
				continue
			}
			if sizeArr[to] > mx {
				mx = sizeArr[to]
			}
		}
		if mx < best {
			best = mx
			centroid = v
		}
	}

	return centroid, compSize
}

// collectPoints walks every vertex reachable from start without stepping into
// par or any removed vertex, and returns one Point per visited vertex.
// start gets (initP, initQ, initW). Crossing an edge e adds (e.dp, e.dq) to
// the coordinates and multiplies the weight by e.w modulo MOD.
//
// The returned slice IS collectPointsBuf, so the next call overwrites its
// contents; callers must consume or copy it first.
func collectPoints(start, par, initP, initQ int, initW int64) []Point {
	collectPointsBuf = collectPointsBuf[:0]
	collectStack = append(collectStack[:0], State{v: start, par: par, p: initP, q: initQ, w: initW})

	for len(collectStack) != 0 {
		top := len(collectStack) - 1
		cur := collectStack[top]
		collectStack = collectStack[:top]

		collectPointsBuf = append(collectPointsBuf, Point{p: cur.p, q: cur.q, w: cur.w})
		for _, e := range g[cur.v] {
			if nb := e.to; !removed[nb] && nb != cur.par {
				collectStack = append(collectStack, State{
					v:   nb,
					par: cur.v,
					p:   cur.p + e.dp,
					q:   cur.q + e.dq,
					w:   cur.w * e.w % MOD,
				})
			}
		}
	}
	return collectPointsBuf
}

// calcF returns, modulo MOD, the product of w_i*w_j over every unordered pair
// {i, j} of points with p_i+p_j <= 0 and q_i+q_j <= 0 (1 if no pair
// qualifies).
//
// When addCentroid is true, an extra point (p=0, q=0, w=1) representing the
// centroid itself is added. A pair (centroid, v) then qualifies when
// p_v <= 0 and q_v <= 0, and it contributes w_v.
//
// calcF does not modify points. It overwrites the package buffers
// tmpNeg/tmpPos and resets the Fenwick tree via bitStart.
// NOTE(review): assumes every index q+qOff and qOff-q lies in [1, bitSize).
// That presumably holds through the qOff/bitSize chosen in main, given that
// |q| <= 2(n-1) on any path.
func calcF(points []Point, addCentroid bool) int64 {
	// Fewer than two points and no centroid: no pair can exist.
	if !addCentroid && len(points) < 2 {
		return 1
	}

	tmpNeg = tmpNeg[:0]
	tmpPos = tmpPos[:0]

	// Split by the sign of p. Two points with p > 0 can never satisfy
	// p_i+p_j <= 0, so tmpPos points only pair with tmpNeg points.
	for _, pt := range points {
		if pt.p <= 0 {
			tmpNeg = append(tmpNeg, pt)
		} else {
			tmpPos = append(tmpPos, pt)
		}
	}
	if addCentroid {
		tmpNeg = append(tmpNeg, Point{0, 0, 1})
	}

	// Sort by p ascending (q breaks ties). The pos/neg sweep below relies on
	// this order to insert tmpNeg points incrementally.
	if len(tmpNeg) > 1 {
		sort.Slice(tmpNeg, func(i, j int) bool {
			if tmpNeg[i].p != tmpNeg[j].p {
				return tmpNeg[i].p < tmpNeg[j].p
			}
			return tmpNeg[i].q < tmpNeg[j].q
		})
	}

	res := int64(1)

	// Pairs inside tmpNeg: both p's are <= 0, so only q_i+q_j <= 0 matters.
	// Points are stored at index q+qOff. The prefix query up to qOff-q
	// therefore returns the k earlier points with q' <= -q, together with the
	// product of their weights. Each such pair contributes w*w', so multiply
	// res by w^k and by that product. (Here p is the returned product, not a
	// coordinate.)
	if len(tmpNeg) > 1 {
		bitStart()
		for _, pt := range tmpNeg {
			k, p := bitQuery(qOff - pt.q)
			// Skip the exponentiation when it cannot change res.
			if k != 0 && pt.w != 1 {
				res = res * powmod(pt.w, k) % MOD
			}
			res = res * p % MOD
			bitUpdate(pt.q+qOff, pt.w)
		}
	}

	// Pairs (pos, neg): require p_neg <= -p_pos as well as the q condition.
	if len(tmpPos) > 0 && len(tmpNeg) > 0 {
		// tmpPos by p descending makes lim = -p non-decreasing, so the tree
		// can be grown monotonically from the ascending tmpNeg.
		if len(tmpPos) > 1 {
			sort.Slice(tmpPos, func(i, j int) bool {
				if tmpPos[i].p != tmpPos[j].p {
					return tmpPos[i].p > tmpPos[j].p
				}
				return tmpPos[i].q < tmpPos[j].q
			})
		}

		bitStart()
		ptr := 0
		for _, pt := range tmpPos {
			lim := -pt.p
			// After this loop the tree holds exactly the tmpNeg points with p <= lim.
			for ptr < len(tmpNeg) && tmpNeg[ptr].p <= lim {
				bitUpdate(tmpNeg[ptr].q+qOff, tmpNeg[ptr].w)
				ptr++
			}
			k, p := bitQuery(qOff - pt.q)
			if k != 0 && pt.w != 1 {
				res = res * powmod(pt.w, k) % MOD
			}
			res = res * p % MOD
		}
	}

	return res
}

// decompose performs centroid decomposition on the component of non-removed
// vertices containing start. It multiplies ans by the contribution of every
// qualifying pair whose path passes through that component's centroid, then
// recurses into the remaining sub-components.
//
// calcF(all, true) treats every pair of component vertices as if its path
// were the concatenation of the two centroid-to-vertex paths, and adds the
// (centroid, v) pairs. That is only the real path when the two vertices lie
// in different child subtrees. For each child subtree, calcF(pts, false)
// evaluates exactly the same-subtree pairs with the same (p, q, w) values.
// Dividing by their product (denom) therefore cancels them.
// NOTE(review): the division needs denom != 0 mod MOD, i.e. no edge weight
// divisible by MOD. This is presumably guaranteed by the input constraints.
func decompose(start int) {
	c, compSize := findCentroid(start)
	removed[c] = true

	// One point per non-centroid vertex of the component.
	all := make([]Point, 0, compSize-1)
	denom := int64(1)

	for _, e := range g[c] {
		if removed[e.to] {
			continue
		}
		// pts aliases collectPointsBuf, which the next collectPoints call
		// overwrites. It must be fully consumed (calcF, copy into all) first.
		pts := collectPoints(e.to, c, e.dp, e.dq, e.w)
		denom = denom * calcF(pts, false) % MOD
		all = append(all, pts...)
	}

	fall := calcF(all, true)
	ans = ans * fall % MOD
	ans = ans * modInv(denom) % MOD

	// The centroid is now removed, so each neighbour starts a separate component.
	for _, e := range g[c] {
		if !removed[e.to] {
			decompose(e.to)
		}
	}
}

// main reads the tree from stdin, runs centroid decomposition from vertex 1
// and prints the accumulated product ans (mod MOD).
//
// Input format (as parsed below): n, then n-1 edges "u v x c" with vertices
// numbered 1..n, edge weight x and flag c. An edge with c == 1 steps (p, q) by
// (1, -2); any other c steps by (-2, 1). For a path with b flag-1 edges and r
// other edges, p = b-2r and q = r-2b. calcF's condition p <= 0 && q <= 0
// therefore means neither count exceeds twice the other.
func main() {
	// Assign to the package-level data; `data, err := ...` would shadow it and
	// leave nextInt reading an empty slice.
	var err error
	data, err = io.ReadAll(os.Stdin)
	if err != nil {
		fmt.Fprintln(os.Stderr, "reading input:", err)
		os.Exit(1)
	}
	n = nextInt()
	if n < 1 {
		// No vertices means no pairs: the result is the empty product, ans == 1.
		// decompose(1) would otherwise index past parentArr.
		fmt.Print(ans)
		return
	}

	g = make([][]Edge, n+1)
	for i := 0; i < n-1; i++ {
		u := nextInt()
		v := nextInt()
		// Keep every stored weight in [0, MOD) so products of two weights
		// always fit in int64; a no-op for weights already below MOD.
		x := int64(nextInt()) % MOD
		c := nextInt()
		dp, dq := -2, 1
		if c == 1 {
			dp, dq = 1, -2
		}
		g[u] = append(g[u], Edge{v, x, dp, dq})
		g[v] = append(g[v], Edge{u, x, dp, dq})
	}

	removed = make([]bool, n+1)
	parentArr = make([]int, n+1)
	sizeArr = make([]int, n+1)

	// |q| <= 2(n-1) on any path, so q+qOff and qOff-q stay within
	// [1, bitSize) with these margins.
	qOff = 2*n + 20
	bitSize = 4*n + 50
	bitCnt = make([]int64, bitSize)
	bitProd = make([]int64, bitSize)
	bitTag = make([]int, bitSize)

	decompose(1)

	out := bufio.NewWriterSize(os.Stdout, 1<<20)
	fmt.Fprint(out, ans)
	out.Flush()
}