← Home
package main

import (
	"bufio"
	"fmt"
	"os"
)

// readInt scans in for the next run of ASCII decimal digits and returns its
// value as an int.
//
// Every byte before the first digit is discarded, including whitespace, signs
// and any other text, so a leading '-' is ignored. The single byte that ends
// the digit run is consumed too. If a read error (including end of input)
// occurs, the digits accumulated so far are returned, or 0 if none were seen.
func readInt(in *bufio.Reader) int {
	isDigit := func(b byte) bool { return '0' <= b && b <= '9' }

	// Skip forward to the first digit (or stop on a read error).
	b, err := in.ReadByte()
	for err == nil && !isDigit(b) {
		b, err = in.ReadByte()
	}

	// Fold consecutive digits into the result; the loop's post statement
	// reads (and thereby consumes) the byte following each digit.
	value := 0
	for ; err == nil && isDigit(b); b, err = in.ReadByte() {
		value = value*10 + int(b-'0')
	}
	return value
}

// main reads a tree instance from standard input and prints the sum, over
// every non-root vertex u, of min(c, 2k-c), where c is the number of marked
// vertices in u's subtree (tree rooted at vertex 1). This is the maximum
// total path length obtainable by splitting the 2k marked vertices into k
// pairs: the edge above u can be crossed by at most min(c, 2k-c) pairs.
//
// Input format, as parsed below: n and k, then 2k marked vertex ids, then
// n-1 undirected edges "u v" over vertices 1..n.
func main() {
	in := bufio.NewReader(os.Stdin)
	n := readInt(in)
	k := readInt(in)

	// marked[v] is 1 when v was listed as a marked vertex, 0 otherwise.
	marked := make([]int, n+1)
	for i := 0; i < 2*k; i++ {
		marked[readInt(in)] = 1
	}

	adj := make([][]int, n+1)
	for i := 0; i < n-1; i++ {
		u := readInt(in)
		v := readInt(in)
		adj[u] = append(adj[u], v)
		adj[v] = append(adj[v], u)
	}

	fmt.Println(maxPairingDistance(adj, marked, 2*k))
}

// bfsOrder returns the vertices reachable from root in breadth-first order,
// together with each vertex's BFS parent (parent[root] and parents of
// unreachable vertices stay 0). adj is an adjacency list indexed by vertex id.
func bfsOrder(adj [][]int, root int) (order, parent []int) {
	visited := make([]bool, len(adj))
	parent = make([]int, len(adj))
	order = make([]int, 0, len(adj))

	order = append(order, root)
	visited[root] = true
	// order doubles as the BFS queue: it grows while we scan it.
	for i := 0; i < len(order); i++ {
		u := order[i]
		for _, v := range adj[u] {
			if !visited[v] {
				visited[v] = true
				parent[v] = u
				order = append(order, v)
			}
		}
	}
	return order, parent
}

// maxPairingDistance roots the tree described by adj at vertex 1 and returns
// the sum, over every non-root vertex u reachable from the root, of
// min(c, total-c), where c is count summed over u's subtree.
//
// count is updated in place: on return, count[v] holds the subtree total for
// every reachable vertex v. Vertices not reachable from vertex 1 (e.g. if the
// edge list is not a connected tree) are ignored rather than causing a panic.
// If adj has no vertex 1, the result is 0.
func maxPairingDistance(adj [][]int, count []int, total int) int64 {
	if len(adj) < 2 {
		return 0
	}
	order, parent := bfsOrder(adj, 1)

	var sum int64
	// Walk BFS order backwards so each child's subtree total is complete
	// before it is added to its parent. Index 0 is the root, which has no
	// parent edge, so it is skipped.
	for i := len(order) - 1; i > 0; i-- {
		u := order[i]
		below := count[u]
		above := total - below
		if below < above {
			sum += int64(below)
		} else {
			sum += int64(above)
		}
		count[parent[u]] += below
	}
	return sum
}