使用 slices.Sort
和 slices.SortFunc
避免 sort.Slices
的坑
sort.Slices 介绍
sort.Slices
是go 于1.18 版本新增的排序函数,签名如下:
func Slice(x any, less func(i, j int) bool)
使用起来非常简单:
func main() {
type Student struct {
Name string
Age int
}
students := []Student{
{Name: "Gopher", Age: 14},
{Name: "Carol", Age: 10},
{Name: "Alice", Age: 10},
{Name: "Bob", Age: 15},
{Name: "Dave", Age: 12},
}
// sort by Age first, Name second
sort.Slice(students, func(i, j int) bool {
x, y := students[i], students[j]
if x.Age != y.Age {
return x.Age < y.Age
}
return x.Name < y.Name
})
for _, student := range students {
fmt.Printf("%d %s\n", student.Age, student.Name)
}
}
// Output:
// 10 Alice
// 10 Carol
// 12 Dave
// 14 Gopher
// 15 Bob
闭包的坑
上面的代码中,甚至进行了多字段排序。但如果我们只需要部分排序,代码又该怎么写呢?
这样对吗?
func SortAfter(nums []int, p int) { // 从p开始排序
sort.Slice(nums[p:], func(i, j int) bool {
return nums[i] < nums[j]
})
}
咋一看没什么问题,跑下测试用例吧:
func main() {
nums := []int{2, 3, 1, 5, 4, 6, 7}
tests := []struct {
p int
want []int
}{
{p: 1, want: []int{2, 1, 3, 4, 5, 6, 7}},
{p: 2, want: []int{2, 3, 1, 4, 5, 6, 7}},
{p: 3, want: []int{2, 3, 1, 4, 5, 6, 7}},
}
for _, tt := range tests {
got := slices.Clone(nums) // 拷贝原始数据,用于测试
SortAfter(got, tt.p)
if !reflect.DeepEqual(got, tt.want) {
fmt.Printf("when p: %v ,want: %v, but got: %v\n", tt.p, tt.want, got)
}
}
}
// Output:
// when p: 1 ,want: [2 1 3 4 5 6 7], but got: [2 3 5 6 7 4 1]
// when p: 3 ,want: [2 3 1 4 5 6 7], but got: [2 3 1 5 6 4 7]
为什么 p = 1,3 时不对,p = 2 时又是对的呢?
因为 sort.Slices 接收到的参数是 nums[p:]
,less
闭包里的参数 i, j
是在 nums[p:]
的位置,如果直接比较 nums[i]
和nums[j]
,那就忽略了 p
偏移的影响,所以实际要比较的元素其实是 nums[i+p]
和 nums[j+p]
,那么修复后的函数:
func SortAfter(nums []int, p int) { // 从p开始排序
sort.Slice(nums[p:], func(i, j int) bool { // 这里传入的slice不再是完整的nums,而是nums[p:]
return nums[i+p] < nums[j+p]
})
}
使用 slices.Sort 优化
手动修正偏移量,可以避免部分排序这个坑一时,但日后依旧有可能因为思维惯性而导致再次踩坑。具体实现也很别扭,也不方便修改。
好在 go 在 1.21 版本新增了了 slices
这个泛型库,里面包含了很多切片的通用操作,其中的 slices.Sort
和 slices.SortFunc
函数就可以避免上面的坑。
这两个函数的签名:
func Sort[S ~[]E, E cmp.Ordered](x S)
func SortFunc[S ~[]E, E any](x S, cmp func(a, b E) int)
使用方法也很简单,直接传入要排序的部分,如果调用 slices.Sort
,还可以省去用手写 less
闭包函数。
最重要的是,按照直觉使用这两个函数就可以避免 sort.Slices
在部分排序时的坑:
func main() {
nums := []int{2, 3, 1, 5, 4, 6, 7}
tests := []struct {
p int
want []int
}{
{p: 0, want: []int{1, 2, 3, 4, 5, 6, 7}},
{p: 1, want: []int{2, 1, 3, 4, 5, 6, 7}},
{p: 2, want: []int{2, 3, 1, 4, 5, 6, 7}},
{p: 3, want: []int{2, 3, 1, 4, 5, 6, 7}},
}
for _, tt := range tests {
got := slices.Clone(nums) // 拷贝原始数据,用于测试
slices.Sort(got[tt.p:])
if !reflect.DeepEqual(got, tt.want) {
fmt.Printf("when p: %v ,want: %v, but got: %v\n", tt.p, tt.want, got)
}
}
}
而使用 sort.SortFunc
,搭配 go 1.22 新增的泛型函数 cmp.Or
,可以更轻松的实现多字段排序,这段代码来自 go1.22 标准库 cmp/cmp_test.go
:
func main() {
type Order struct {
Product string
Customer string
Price float64
}
orders := []Order{
{"foo", "alice", 1.00},
{"bar", "bob", 3.00},
{"baz", "carol", 4.00},
{"foo", "alice", 2.00},
{"bar", "carol", 1.00},
{"foo", "bob", 4.00},
}
// Sort by customer first, product second, and last by higher price
slices.SortFunc(orders, func(a, b Order) int {
return cmp.Or(
cmp.Compare(a.Customer, b.Customer),
cmp.Compare(a.Product, b.Product),
cmp.Compare(b.Price, a.Price),
)
})
for _, order := range orders {
fmt.Printf("%s %s %.2f\n", order.Product, order.Customer, order.Price)
}
}
// Output:
// foo alice 2.00
// foo alice 1.00
// bar bob 3.00
// foo bob 4.00
// bar carol 1.00
// baz carol 4.00
其中 cmp.Or
的源码 非常简单,只是用于找出切片中第一个非零的元素:
// Or returns the first of its arguments that is not equal to the zero value.
// If no argument is non-zero, it returns the zero value.
func Or[T comparable](vals ...T) T {
var zero T
for _, val := range vals {
if val != zero {
return val
}
}
return zero
}
可以看到,这样写 less
函数,比手写多个 if
优雅太多了。