← Home
```go
package main

import (
	"bufio"
	"fmt"
	"os"
)

// main reads a tree from standard input and prints countCuts for it.
//
// Input format (whitespace-separated non-negative integers): n k, then the n
// vertex colours col[1..n] (0 = uncoloured), then n-1 edges "u v".
// If n is 0, nothing is printed.
func main() {
	in := bufio.NewReaderSize(os.Stdin, 1<<20)
	n := readInt(in)
	k := readInt(in)
	if n == 0 {
		return
	}

	col := make([]int, n+1)
	for i := 1; i <= n; i++ {
		col[i] = readInt(in)
	}
	edges := make([][2]int, n-1)
	for i := range edges {
		edges[i][0] = readInt(in)
		edges[i][1] = readInt(in)
	}

	fmt.Println(countCuts(n, k, col, edges))
}

// modulus is the value every count is reduced by.
const modulus = 998244353

// readInt skips any non-digit bytes, then parses the following run of decimal
// digits as a non-negative int. It returns 0 if the reader is exhausted before
// a digit is found. Minus signs are not recognised.
func readInt(r *bufio.Reader) int {
	b, err := r.ReadByte()
	for err == nil && (b < '0' || b > '9') {
		b, err = r.ReadByte()
	}
	v := 0
	for err == nil && b >= '0' && b <= '9' {
		v = v*10 + int(b-'0')
		b, err = r.ReadByte()
	}
	return v
}

// countCuts counts, modulo modulus, the ways to delete edges of the tree so
// that every resulting component contains coloured vertices of exactly one
// colour, and all vertices of a colour stay in a single component.
// Uncoloured vertices may join any component.
//
// Vertices are numbered 1..n. col[v] is the colour of v (0 = uncoloured;
// col[0] is ignored), and edges lists the n-1 tree edges. col is not modified.
// NOTE: colours are assumed to lie in 0..k; a larger colour panics with an
// index out of range, as the original program did.
// The result is 0 when no valid set of cuts exists.
func countCuts(n, k int, col []int, edges [][2]int) int {
	if n <= 0 {
		return 0
	}
	adj := make([][]int, n+1)
	for _, e := range edges {
		adj[e[0]] = append(adj[e[0]], e[1])
		adj[e[1]] = append(adj[e[1]], e[0])
	}
	parent, depth, _ := rootTree(adj, 1)
	up := buildLifting(parent)

	// Work on a copy: painting overwrites the colour of uncoloured vertices.
	color := append([]int(nil), col[:n+1]...)
	if !paintColorPaths(k, color, parent, depth, up) {
		return 0
	}
	return countColoredPartitions(n, k, color, edges)
}

// rootTree roots the graph given by adjacency lists adj at root, using an
// explicit stack so that stack usage does not grow with tree depth.
// It returns each vertex's parent (0 for root and for unreached vertices),
// its depth below root, and the vertices in the order they were visited.
// Every vertex appears in order after all of its ancestors.
func rootTree(adj [][]int, root int) ([]int, []int, []int) {
	size := len(adj)
	parent := make([]int, size)
	depth := make([]int, size)
	order := make([]int, 0, size)
	seen := make([]bool, size)
	seen[root] = true
	stack := []int{root}
	for len(stack) > 0 {
		u := stack[len(stack)-1]
		stack = stack[:len(stack)-1]
		order = append(order, u)
		for _, v := range adj[u] {
			if !seen[v] {
				seen[v] = true
				parent[v] = u
				depth[v] = depth[u] + 1
				stack = append(stack, v)
			}
		}
	}
	return parent, depth, order
}

// buildLifting returns the binary-lifting table up, where up[j][v] is the
// 2^j-th ancestor of v. The jump saturates at 0 once it passes the root,
// which relies on parent[0] == 0.
// The number of levels is derived from len(parent), so any depth difference
// in the tree can be expressed with the available jumps.
func buildLifting(parent []int) [][]int {
	levels := 1
	for 1<<levels < len(parent) {
		levels++
	}
	up := make([][]int, levels)
	up[0] = parent
	for j := 1; j < levels; j++ {
		prev := up[j-1]
		cur := make([]int, len(parent))
		for v, mid := range prev {
			cur[v] = prev[mid]
		}
		up[j] = cur
	}
	return up
}

// lca returns the lowest common ancestor of a and b, given the lifting table
// from buildLifting and the depths from rootTree.
func lca(up [][]int, depth []int, a, b int) int {
	if depth[a] < depth[b] {
		a, b = b, a
	}
	// Lift a to b's depth one binary digit of the difference at a time.
	for j, diff := 0, depth[a]-depth[b]; diff > 0; j, diff = j+1, diff>>1 {
		if diff&1 == 1 {
			a = up[j][a]
		}
	}
	if a == b {
		return a
	}
	for j := len(up) - 1; j >= 0; j-- {
		if up[j][a] != up[j][b] {
			a, b = up[j][a], up[j][b]
		}
	}
	return up[0][a]
}

// paintColorPaths makes each colour class connected. For every colour c with
// at least one vertex, it paints c on all vertices on the paths from each
// c-vertex up to the LCA of all c-vertices. color is modified in place.
// It reports false as soon as such a path meets a vertex that already holds
// a different colour; then no valid set of cuts exists.
func paintColorPaths(k int, color, parent, depth []int, up [][]int) bool {
	members := make([][]int, k+1)
	for v := 1; v < len(color); v++ {
		if c := color[v]; c > 0 {
			members[c] = append(members[c], v)
		}
	}

	painted := make([]bool, len(color))
	for c := 1; c <= k; c++ {
		vs := members[c]
		if len(vs) == 0 {
			continue
		}
		top := vs[0]
		for _, v := range vs[1:] {
			top = lca(up, depth, top, v)
		}
		for _, v := range vs {
			for x := v; depth[x] > depth[top]; x = parent[x] {
				if color[x] != 0 && color[x] != c {
					return false
				}
				if painted[x] {
					// An earlier vertex of colour c already painted from
					// here up to top.
					break
				}
				color[x] = c
				painted[x] = true
			}
		}
		if color[top] != 0 && color[top] != c {
			return false
		}
		color[top] = c
		painted[top] = true
	}
	return true
}

// countColoredPartitions contracts every colour class of the painted tree
// into one node. An uncoloured vertex v keeps id v; a vertex of colour c
// becomes node n+c.
// It then counts, modulo modulus, the cut sets of the contracted tree in
// which every component contains exactly one coloured node.
func countColoredPartitions(n, k int, color []int, edges [][2]int) int {
	id := make([]int, n+1)
	for v := 1; v <= n; v++ {
		id[v] = v
		if color[v] != 0 {
			id[v] = n + color[v]
		}
	}
	tree := make([][]int, n+k+1)
	for _, e := range edges {
		a, b := id[e[0]], id[e[1]]
		if a != b {
			tree[a] = append(tree[a], b)
			tree[b] = append(tree[b], a)
		}
	}

	root := id[1]
	parent, _, order := rootTree(tree, root)

	// dp[u][0]: ways inside u's subtree where u's component has no coloured
	// node; dp[u][1]: ways where it has exactly one. Components already cut
	// away below u always hold exactly one coloured node.
	// int64 keeps the products below 2^62 even where int is 32-bit.
	dp := make([][2]int64, len(tree))
	for _, u := range order {
		if u > n {
			dp[u] = [2]int64{0, 1}
		} else {
			dp[u] = [2]int64{1, 0}
		}
	}
	// Reverse visit order finishes every child before its parent.
	for i := len(order) - 1; i > 0; i-- {
		v := order[i]
		p := parent[v]
		// For either state of p: keep edge (p,v) with an uncoloured v-side,
		// or cut it when the v-side is coloured.
		ways := (dp[v][0] + dp[v][1]) % modulus
		// An uncoloured p-side merging with a coloured v-side moves p from
		// state 0 to state 1; update dp[p][1] before dp[p][0] is overwritten.
		dp[p][1] = (dp[p][1]*ways + dp[p][0]*dp[v][1]) % modulus
		dp[p][0] = dp[p][0] * ways % modulus
	}
	return int(dp[root][1])
}
```