-
Notifications
You must be signed in to change notification settings - Fork 0
/
errgroup.go
123 lines (111 loc) · 3.43 KB
/
errgroup.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Package errgroup provides synchronization, error propagation, and Context
// cancellation for groups of goroutines working on subtasks of a common task.
//
// It wraps, and exposes a similar API to, the upstream package
// golang.org/x/sync/errgroup. Our version additionally recovers from panics,
// converting them into errors.
package errgroup
import (
"context"
"fmt"
"runtime"
"sync"
)
// FromPanicValue takes a value recovered from a panic and converts it into an
// error, for logging purposes. The error message ends with the calling
// goroutine's stack trace as returned by CollectStack. If the value is nil,
// it returns nil instead of an error.
//
// The message prefix depends on the dynamic type of the value:
//   - string: "panic: <value>"
//   - error: "panic in errgroup goroutine: <value>"; the error is wrapped
//     with %w, so errors.Is and errors.As still see it
//   - anything else: "unknown panic: <value formatted with %+v>"
//
// Use like:
//
//	defer func() {
//		err := FromPanicValue(recover())
//		// log or otherwise use err
//	}()
func FromPanicValue(i interface{}) error {
	switch value := i.(type) {
	case nil:
		return nil
	case string:
		return fmt.Errorf("panic: %v\n%s", value, CollectStack())
	case error:
		// ": " separates the prefix from the wrapped error's text; without it
		// the two ran together (e.g. "...goroutine boom").
		return fmt.Errorf("panic in errgroup goroutine: %w\n%s", value, CollectStack())
	default:
		return fmt.Errorf("unknown panic: %+v\n%s", value, CollectStack())
	}
}
// CollectStack returns the formatted stack trace of the calling goroutine
// only (runtime.Stack with all=false). At most 64 KiB of trace is captured;
// anything longer is cut off by runtime.Stack.
func CollectStack() []byte {
	const maxStackBytes = 64 << 10
	trace := make([]byte, maxStackBytes)
	n := runtime.Stack(trace, false)
	return trace[:n]
}
// catchPanics returns a wrapper around f. Calling the wrapper calls f and
// returns f's error unchanged. If f panics instead, the panic is recovered
// and the wrapper returns FromPanicValue applied to the recovered value.
//
// This is adapted from log.PanicHandler. Instead of calling log.Panic, it
// reports the panic through the wrapper's (named) error result. Nothing is
// logged here: Group.Go hands that error back to the caller of Wait, which
// is sufficient.
func catchPanics(f func() error) func() error {
	wrapped := func() (result error) {
		defer func() {
			// recover must be called directly by the deferred function.
			r := recover()
			if r == nil {
				return
			}
			result = FromPanicValue(r)
		}()
		result = f()
		return result
	}
	return wrapped
}
// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
	// This is a copy of the upstream golang.org/x/sync/errgroup.Group rather
	// than a wrapper around it. Supporting a usable zero Group would mean
	// embedding the upstream Group by value. WithContext would then have to
	// copy the upstream *Group it gets back, and that type must not be
	// copied. (A private initialization Once could work around this, but that
	// is more convoluted than copying the small implementation.) The only
	// change from upstream is the catchPanics wrapping in Go.

	// cancel is nil for a zero Group; WithContext sets it to the cancel
	// function of the derived Context.
	cancel func()

	wg sync.WaitGroup // tracks goroutines started by Go

	errOnce sync.Once // ensures err is written at most once
	err     error     // first non-nil error (or converted panic) seen by Go
}

// WithContext returns a new Group together with a Context derived from ctx.
//
// The derived Context is canceled as soon as a function passed to Go returns
// a non-nil error or panics, or Wait returns, whichever happens first.
func WithContext(ctx context.Context) (*Group, context.Context) {
	derived, cancel := context.WithCancel(ctx)
	g := &Group{cancel: cancel}
	return g, derived
}

// Wait blocks until every function started with Go has returned. It then
// cancels the Group's Context (only if the Group came from WithContext). It
// returns the first non-nil error reported by those functions, or nil.
func (g *Group) Wait() error {
	g.wg.Wait()
	if cancel := g.cancel; cancel != nil {
		cancel()
	}
	return g.err
}

// Go runs f in a new goroutine.
//
// The first call to return a non-nil error cancels the group (when it was
// created with WithContext); that error is the one Wait returns.
//
// A panic inside f is recovered and handled as though f had returned an
// error describing the panic.
func (g *Group) Go(f func() error) {
	g.wg.Add(1)
	// The catchPanics wrapping is the one departure from upstream, which
	// calls f directly.
	guarded := catchPanics(f)
	go func() {
		defer g.wg.Done()
		err := guarded()
		if err == nil {
			return
		}
		g.errOnce.Do(func() {
			g.err = err
			if cancel := g.cancel; cancel != nil {
				cancel()
			}
		})
	}()
}