-
Notifications
You must be signed in to change notification settings - Fork 99
/
transport.go
79 lines (61 loc) · 1.22 KB
/
transport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
package smb2
import (
"errors"
"io"
"net"
)
const (
maxDirectTCPSize = 0xffffff // 16777215
// maxNetBTSize = 0x1ffff // 131071
)
// transport abstracts the framing used to exchange messages over a
// connection.
//
// Write sends one framed message. ReadSize consumes the next frame header and
// returns the payload length it announces. Read fills a caller-supplied
// buffer with payload bytes. Close releases the underlying connection.
type transport interface {
	// Write sends p as a single framed message.
	Write(p []byte) (int, error)
	// ReadSize reads the next frame header and returns the payload length.
	ReadSize() (int, error)
	// Read fills p with payload bytes from the connection.
	Read(p []byte) (int, error)
	// Close closes the underlying connection.
	Close() error
}
// directTCP implements transport on top of the wrapped net.Conn, prefixing
// every message with a 4-byte length header.
type directTCP struct {
	// sb and rb hold the outgoing and incoming 4-byte headers. They live in the
	// struct so Write and ReadSize need no per-call buffer.
	// NOTE(review): because they are shared, Write (and likewise ReadSize)
	// must not run concurrently on one value. Presumably callers serialize;
	// confirm.
	sb, rb [4]byte

	conn net.Conn // wrapped connection all I/O goes through
}
// direct wraps tcpConn in Direct TCP framing and returns it as a transport.
// The caller hands over tcpConn; closing the returned transport closes it.
func direct(tcpConn net.Conn) transport {
	dt := new(directTCP)
	dt.conn = tcpConn
	return dt
}
// Write sends p as one Direct TCP message. The message is a 4-byte header
// holding len(p), encoded with the package's `be` byte order (presumably
// big-endian), followed by p itself. The header's top byte is always zero
// because len(p) is capped at maxDirectTCPSize.
//
// On success it returns the total number of bytes put on the wire, which is
// len(p)+4. If p is too large or the connection write fails, it returns -1
// and the error.
func (t *directTCP) Write(p []byte) (n int, err error) {
	if len(p) > maxDirectTCPSize {
		return -1, errors.New("max transport size exceeds")
	}

	hdr := t.sb[:]
	be.PutUint32(hdr, uint32(len(p)))

	// Give the header and payload to the connection together. On
	// *net.TCPConn this becomes a single vectored write (writev) where the
	// platform supports it. That way the 4-byte header is not sent as its own
	// tiny segment (Go enables TCP_NODELAY by default), and each message costs
	// one syscall instead of two. Other net.Conn implementations fall back to
	// writing the two buffers one after another.
	bufs := net.Buffers{hdr, p}
	written, err := bufs.WriteTo(t.conn)
	if err != nil {
		return -1, err
	}
	return int(written), nil
}
// ReadSize consumes the next 4-byte Direct TCP header from the connection and
// returns the payload length it announces. The length is decoded with the
// package's `be` byte order.
//
// The header's first byte must be zero, since lengths are limited to 24 bits
// (see maxDirectTCPSize). Any other value is reported as a framing error. On
// any failure it returns -1 with the error.
func (t *directTCP) ReadSize() (size int, err error) {
	hdr := t.rb[:]
	if _, err = io.ReadFull(t.conn, hdr); err != nil {
		return -1, err
	}
	if hdr[0] == 0 {
		return int(be.Uint32(hdr)), nil
	}
	return -1, errors.New("invalid transport format")
}
// Read fills p completely from the connection using io.ReadFull and returns
// len(p). If the stream ends or fails before p is full, it returns -1 and the
// error from io.ReadFull. On a short stream that error is io.EOF or
// io.ErrUnexpectedEOF.
func (t *directTCP) Read(p []byte) (n int, err error) {
	if n, err = io.ReadFull(t.conn, p); err != nil {
		return -1, err
	}
	return n, nil
}
// Close closes the wrapped connection and returns its error, if any.
func (t *directTCP) Close() error {
	conn := t.conn
	return conn.Close()
}