```go
package main
import (
"bufio"
"fmt"
"os"
)
// mod is the prime modulus applied to the partition count.
const mod = 998244353

// main reads a tree from stdin and prints the number of ways, modulo mod, to
// delete edges so that every resulting component holds all vertices of
// exactly one color (uncolored vertices may end up in any component).
//
// Input: n and k, then the color of each vertex 1..n (0 = uncolored,
// otherwise 1..k), then n-1 edges "u v". Nothing is printed when n is 0.
func main() {
	reader := bufio.NewReaderSize(os.Stdin, 1024*1024)
	// readInt skips non-digit bytes, then parses a non-negative decimal
	// integer; it returns 0 if input runs out before any digit is seen.
	readInt := func() int {
		res := 0
		b, err := reader.ReadByte()
		for err == nil && (b < '0' || b > '9') {
			b, err = reader.ReadByte()
		}
		if err != nil {
			return 0
		}
		for err == nil && b >= '0' && b <= '9' {
			res = res*10 + int(b-'0')
			b, err = reader.ReadByte()
		}
		return res
	}
	n := readInt()
	k := readInt()
	if n == 0 {
		return
	}
	col := make([]int, n+1)
	for v := 1; v <= n; v++ {
		col[v] = readInt()
	}
	edges := make([][2]int, n-1)
	for i := range edges {
		u := readInt()
		v := readInt()
		edges[i] = [2]int{u, v}
	}
	fmt.Println(countColorPartitions(n, k, col, edges))
}

// countColorPartitions returns, modulo mod, the number of edge sets whose
// removal splits the tree into components that each contain all vertices of
// exactly one color. It returns 0 when the minimal subtrees spanning two
// different colors share a vertex, since then no such split exists.
//
// Vertices are 1..n; col[v] is the color of v (0 = uncolored, else 1..k) and
// must have length at least n+1; edges holds the n-1 tree edges. col is not
// modified.
func countColorPartitions(n, k int, col []int, edges [][2]int) int64 {
	adj := make([][]int, n+1)
	for _, e := range edges {
		adj[e[0]] = append(adj[e[0]], e[1])
		adj[e[1]] = append(adj[e[1]], e[0])
	}
	painted, ok := paintColorSubtrees(n, k, col, adj)
	if !ok {
		return 0
	}
	return countCuts(n, k, painted, adj)
}

// rootTree roots the tree at vertex 1 with a BFS (no recursion, so deep
// trees are safe) and returns each vertex's parent (0 for the root) and depth.
func rootTree(n int, adj [][]int) (parent, depth []int) {
	parent = make([]int, n+1)
	depth = make([]int, n+1)
	queue := make([]int, 0, n)
	queue = append(queue, 1)
	for i := 0; i < len(queue); i++ {
		u := queue[i]
		for _, v := range adj[u] {
			if v != parent[u] {
				parent[v] = u
				depth[v] = depth[u] + 1
				queue = append(queue, v)
			}
		}
	}
	return parent, depth
}

// lcaTable answers lowest-common-ancestor queries by binary lifting.
type lcaTable struct {
	up    [][]int // up[j][v] is the 2^j-th ancestor of v, or 0 past the root
	depth []int
}

// newLCATable builds the lifting table from parent/depth arrays. The number
// of levels grows with n, so any depth below n is supported.
func newLCATable(n int, parent, depth []int) *lcaTable {
	levels := 1
	for reach := 2; reach <= n; reach <<= 1 {
		levels++
	}
	up := make([][]int, levels)
	up[0] = parent
	for j := 1; j < levels; j++ {
		prev := up[j-1]
		cur := make([]int, n+1)
		for v := 1; v <= n; v++ {
			// prev[0] == 0, so jumps past the root stay at the sentinel 0.
			cur[v] = prev[prev[v]]
		}
		up[j] = cur
	}
	return &lcaTable{up: up, depth: depth}
}

// lca returns the lowest common ancestor of u and v.
func (t *lcaTable) lca(u, v int) int {
	if t.depth[u] < t.depth[v] {
		u, v = v, u
	}
	// Lift the deeper vertex to the same depth, one set bit at a time.
	for j, diff := 0, t.depth[u]-t.depth[v]; diff > 0; j, diff = j+1, diff>>1 {
		if diff&1 == 1 {
			u = t.up[j][u]
		}
	}
	if u == v {
		return u
	}
	for j := len(t.up) - 1; j >= 0; j-- {
		if t.up[j][u] != t.up[j][v] {
			u, v = t.up[j][u], t.up[j][v]
		}
	}
	return t.up[0][u]
}

// paintColorSubtrees extends every color c to the minimal subtree spanning
// all vertices of color c. It returns the new coloring, or false if some
// vertex would need two different colors.
func paintColorSubtrees(n, k int, col []int, adj [][]int) ([]int, bool) {
	parent, depth := rootTree(n, adj)
	tab := newLCATable(n, parent, depth)

	members := make([][]int, k+1)
	for v := 1; v <= n; v++ {
		if col[v] > 0 {
			members[col[v]] = append(members[col[v]], v)
		}
	}

	painted := make([]int, n+1)
	copy(painted, col)
	// walked[v] marks vertices already climbed through for their color; the
	// path from such a vertex up to that color's top is already painted.
	walked := make([]bool, n+1)
	for c := 1; c <= k; c++ {
		vs := members[c]
		if len(vs) == 0 {
			continue
		}
		top := vs[0]
		for _, v := range vs[1:] {
			top = tab.lca(top, v)
		}
		for _, v := range vs {
			for cur := v; depth[cur] > depth[top]; cur = parent[cur] {
				if painted[cur] != 0 && painted[cur] != c {
					return nil, false
				}
				if walked[cur] {
					break
				}
				painted[cur] = c
				walked[cur] = true
			}
		}
		if painted[top] != 0 && painted[top] != c {
			return nil, false
		}
		painted[top] = c
	}
	return painted, true
}

// countCuts contracts each painted color region into one node (n+c for
// color c) and counts, modulo mod, the edge sets of the contracted tree whose
// removal leaves every component with exactly one color node.
func countCuts(n, k int, painted []int, adj [][]int) int64 {
	id := make([]int, n+1)
	for v := 1; v <= n; v++ {
		if painted[v] == 0 {
			id[v] = v
		} else {
			id[v] = n + painted[v]
		}
	}
	// Painted regions are connected, so the contraction is still a tree.
	tree := make([][]int, n+k+1)
	for u := 1; u <= n; u++ {
		for _, v := range adj[u] {
			if u < v && id[u] != id[v] {
				tree[id[u]] = append(tree[id[u]], id[v])
				tree[id[v]] = append(tree[id[v]], id[u])
			}
		}
	}

	root := id[1]
	parent := make([]int, n+k+1)
	order := []int{root}
	for i := 0; i < len(order); i++ {
		u := order[i]
		for _, v := range tree[u] {
			if v != parent[u] {
				parent[v] = u
				order = append(order, v)
			}
		}
	}

	// none[u]/one[u]: ways to cut u's subtree so that u's component has no /
	// exactly one color node, with every component already cut off below u
	// holding exactly one. int64 keeps the products exact on 32-bit builds.
	none := make([]int64, n+k+1)
	one := make([]int64, n+k+1)
	for i := len(order) - 1; i >= 0; i-- {
		u := order[i]
		if u > n {
			one[u] = 1
		} else {
			none[u] = 1
		}
		for _, v := range tree[u] {
			if v == parent[u] {
				continue
			}
			// Edge u-v may be cut only when v's side is complete (one[v]);
			// keeping it merges v's component into u's.
			keepOrCut := (none[v] + one[v]) % mod
			one[u] = (one[u]*keepOrCut + none[u]*one[v]) % mod
			none[u] = none[u] * keepOrCut % mod
		}
	}
	return one[root]
}
```