-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpacket-manager.go
173 lines (153 loc) · 4.15 KB
/
packet-manager.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
package sftp
import (
"encoding"
"sort"
"sync"
)
// The goal of the packetManager is to keep the outgoing packets in the same
// order as the incoming ones. This is because some SFTP clients require this
// behavior (e.g. WinSCP).
// packetSender is the sink the packetManager hands ready responses to.
// maybeSendPackets calls sendPacket once per response, in request-id order.
type packetSender interface {
	// sendPacket takes a packet that can marshal itself to binary and
	// reports any failure through the returned error.
	sendPacket(encoding.BinaryMarshaler) error
}
// packetManager pairs outstanding request ids with finished responses and
// releases the responses to sender in ascending id order. Within this file
// the incoming and outgoing slices are only touched by the controller
// goroutine (controller and maybeSendPackets).
type packetManager struct {
	// requests carries packets registered through incomingPacket.
	requests chan requestPacket
	// responses carries packets registered through readyPacket.
	responses chan responsePacket
	// fini is closed by close() to make the controller goroutine return.
	fini chan struct{}
	// incoming holds ids of requests still awaiting a response, kept sorted.
	incoming requestPacketIDs
	// outgoing holds responses not yet sent, kept sorted by id.
	outgoing responsePackets
	sender   packetSender // connection object
	// working counts packets registered via incomingPacket that have not yet
	// been passed to readyPacket.
	working *sync.WaitGroup
}
// newPktMgr builds a packetManager that delivers responses through sender
// and starts its controller goroutine before returning. Both channels and
// both pending slices are sized to SftpServerWorkerCount.
func newPktMgr(sender packetSender) *packetManager {
	mgr := new(packetManager)
	mgr.requests = make(chan requestPacket, SftpServerWorkerCount)
	mgr.responses = make(chan responsePacket, SftpServerWorkerCount)
	mgr.fini = make(chan struct{})
	mgr.incoming = make([]uint32, 0, SftpServerWorkerCount)
	mgr.outgoing = make([]responsePacket, 0, SftpServerWorkerCount)
	mgr.sender = sender
	mgr.working = &sync.WaitGroup{}

	// the controller runs until close() closes mgr.fini
	go mgr.controller()
	return mgr
}
// responsePackets is a list of responses waiting to be sent.
type responsePackets []responsePacket

// Sort reorders the responses in place by ascending packet id.
func (r responsePackets) Sort() {
	byID := func(a, b int) bool {
		return r[a].id() < r[b].id()
	}
	sort.Slice(r, byID)
}
// requestPacketIDs is a list of ids of requests awaiting a response.
type requestPacketIDs []uint32

// Sort reorders the ids in place into ascending order.
func (r requestPacketIDs) Sort() {
	ascending := func(a, b int) bool {
		return r[a] < r[b]
	}
	sort.Slice(r, ascending)
}
// incomingPacket registers an incoming packet to be handled.
// Send an id of 0 for packets without an id.
//
// The packet is counted in s.working before it is queued, so a later
// close() waits until readyPacket has been called for it. The send blocks
// while the requests buffer is full.
func (s *packetManager) incomingPacket(pkt requestPacket) {
	s.working.Add(1)
	s.requests <- pkt // buffer == SftpServerWorkerCount
}
// readyPacket registers an outgoing packet as being ready to send.
//
// The response is queued before the working count is decremented, so once
// s.working.Wait() returns every response has already been placed on the
// responses channel.
func (s *packetManager) readyPacket(pkt responsePacket) {
	s.responses <- pkt
	s.working.Done()
}
// close shuts down the packetManager controller. It blocks until every
// packet registered with incomingPacket has been matched by a readyPacket
// call, then closes fini, which the controller loop treats as its signal to
// return.
func (s *packetManager) close() {
	// pause until current packets are processed
	s.working.Wait()
	close(s.fini)
}
// workerChan is passed a worker function and returns a channel for incoming
// packets. The goal is to process packets in the order they are received, as
// is required by section 7 of the RFC, while maximizing throughput of file
// transfers.
//
// runWorker is invoked SftpServerWorkerCount times on a shared buffered
// channel and once more on an unbuffered channel. A dispatcher goroutine
// routes each packet read from the returned channel to one of the two.
func (s *packetManager) workerChan(runWorker func(requestChan)) requestChan {
	// shared by the pool of workers; used between a file open and its close
	fileOps := make(chan requestPacket, SftpServerWorkerCount)
	for started := 0; started < SftpServerWorkerCount; started++ {
		runWorker(fileOps)
	}

	// served by a single worker; used for everything else
	commands := make(chan requestPacket)
	runWorker(commands)

	inbound := make(chan requestPacket, SftpServerWorkerCount)
	go func() {
		// dispatching begins on the command channel
		target := commands
		for pkt := range inbound {
			switch pkt.(type) {
			case *sshFxpOpenPacket:
				// a file was opened: route to the worker pool from here on
				target = fileOps
			case *sshFxpClosePacket:
				// a file is being closed: let outstanding packets finish,
				// then go back to the command channel
				s.working.Wait()
				target = commands
			}
			s.incomingPacket(pkt)
			target <- pkt
		}
		// the inbound channel was closed: stop the workers, then the manager
		close(fileOps)
		close(commands)
		s.close()
	}()

	return inbound
}
// controller processes packets. It is the event loop started by newPktMgr:
// it records the id of every registered request and every ready response,
// keeps both lists sorted by id, and after each event sends whatever
// responses are next in line.
//
// When fini is closed, the loop drains anything still buffered on the
// requests and responses channels before returning. select chooses at
// random among ready cases, so without the drain a closed fini could win
// over a pending response and the last packets would never be sent. close()
// only closes fini after every response has been queued, so a non-blocking
// drain sees all of them.
func (s *packetManager) controller() {
	addRequest := func(pkt requestPacket) {
		debug("incoming id: %v", pkt.id())
		s.incoming = append(s.incoming, pkt.id())
		if len(s.incoming) > 1 {
			s.incoming.Sort()
		}
	}
	addResponse := func(pkt responsePacket) {
		debug("outgoing pkt: %v", pkt.id())
		s.outgoing = append(s.outgoing, pkt)
		if len(s.outgoing) > 1 {
			s.outgoing.Sort()
		}
	}

	for {
		select {
		case pkt := <-s.requests:
			addRequest(pkt)
		case pkt := <-s.responses:
			addResponse(pkt)
		case <-s.fini:
			// shutting down: collect what is still buffered, flush, and exit
			for {
				select {
				case pkt := <-s.requests:
					addRequest(pkt)
				case pkt := <-s.responses:
					addResponse(pkt)
				default:
					s.maybeSendPackets()
					return
				}
			}
		}
		s.maybeSendPackets()
	}
}
// maybeSendPackets sends as many packets as are ready. A response is ready
// when its id equals the lowest pending request id; both lists are kept
// sorted by the controller, so only the two heads need comparing. Sending
// stops at the first pending request whose response has not arrived yet.
//
// A failed send is logged through debug and the packet is still dequeued,
// so one bad write does not stall the responses queued behind it.
func (s *packetManager) maybeSendPackets() {
	for {
		if len(s.outgoing) == 0 || len(s.incoming) == 0 {
			debug("break! -- outgoing: %v; incoming: %v",
				len(s.outgoing), len(s.incoming))
			return
		}
		out := s.outgoing[0]
		// debug("incoming: %v", s.incoming)
		// debug("outgoing: %v", outfilter(s.outgoing))
		if s.incoming[0] != out.id() {
			// the lowest pending request has no response yet; sending
			// anything else would break ordering
			return
		}
		if err := s.sender.sendPacket(out); err != nil {
			debug("send packet %v: %v", out.id(), err)
		}
		// pop off heads: shift left, then drop the last element
		s.incoming = s.incoming[:copy(s.incoming, s.incoming[1:])]
		kept := copy(s.outgoing, s.outgoing[1:])
		// clear the vacated slot so the backing array does not keep a
		// stale reference to an already-sent packet
		var none responsePacket
		s.outgoing[kept] = none
		s.outgoing = s.outgoing[:kept]
	}
}
//func outfilter(o []responsePacket) []uint32 {
// res := make([]uint32, 0, len(o))
// for _, v := range o {
// res = append(res, v.id())
// }
// return res
//}