Skip to content

Commit

Permalink
Merge pull request #120 from ethuleau/traffic-class
Browse files Browse the repository at this point in the history
Add support for setting traffic class on outgoing packets
  • Loading branch information
SuperQ authored Nov 22, 2024
2 parents 622e8b3 + cd05b29 commit 543f9b2
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 6 deletions.
16 changes: 10 additions & 6 deletions cmd/ping/ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@ import (
probing "github.com/prometheus-community/pro-bing"
)

var usage = `
Usage:
ping [-c count] [-i interval] [-t timeout] [-I interface] [--privileged] host
var examples = `
Examples:
# ping google continuously
Expand All @@ -37,6 +33,9 @@ Examples:
# Send ICMP messages with a 100-byte payload
ping -s 100 1.1.1.1
# Send ICMP messages with DSCP CS4 and ECN bits set to 0
ping -Q 128 8.8.8.8
`

func main() {
Expand All @@ -46,9 +45,13 @@ func main() {
size := flag.Int("s", 24, "")
ttl := flag.Int("l", 64, "TTL")
iface := flag.String("I", "", "interface name")
tclass := flag.Int("Q", 192, "Set Quality of Service related bits in ICMP datagrams (DSCP + ECN bits). Only decimal number supported")
privileged := flag.Bool("privileged", false, "")
flag.Usage = func() {
fmt.Print(usage)
out := flag.CommandLine.Output()
fmt.Fprintf(out, "Usage of %s:\n", os.Args[0])
flag.PrintDefaults()
fmt.Fprint(out, examples)
}
flag.Parse()

Expand Down Expand Up @@ -96,6 +99,7 @@ func main() {
pinger.TTL = *ttl
pinger.InterfaceName = *iface
pinger.SetPrivileged(*privileged)
pinger.SetTrafficClass(uint8(*tclass))

fmt.Printf("PING %s (%s):\n", pinger.Addr(), pinger.IPAddr())
err = pinger.Run()
Expand Down
9 changes: 9 additions & 0 deletions packetconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type packetConn interface {
SetDoNotFragment() error
SetBroadcastFlag() error
SetIfIndex(ifIndex int)
SetTrafficClass(uint8) error
}

type icmpConn struct {
Expand Down Expand Up @@ -58,6 +59,10 @@ func (c *icmpv4Conn) SetFlagTTL() error {
return err
}

func (c *icmpv4Conn) SetTrafficClass(tclass uint8) error {
return c.c.IPv4PacketConn().SetTOS(int(tclass))
}

func (c *icmpv4Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
ttl := -1
n, cm, src, err := c.c.IPv4PacketConn().ReadFrom(b)
Expand Down Expand Up @@ -99,6 +104,10 @@ func (c *icmpV6Conn) SetFlagTTL() error {
return err
}

func (c *icmpV6Conn) SetTrafficClass(tclass uint8) error {
return c.c.IPv6PacketConn().SetTrafficClass(int(tclass))
}

func (c *icmpV6Conn) ReadFrom(b []byte) (int, int, net.Addr, error) {
ttl := -1
n, cm, src, err := c.c.IPv6PacketConn().ReadFrom(b)
Expand Down
22 changes: 22 additions & 0 deletions ping.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ func New(addr string) *Pinger {
protocol: "udp",
awaitingSequences: firstSequence,
TTL: 64,
tclass: 192, // CS6 (network control)
logger: StdLogger{Logger: log.New(log.Writer(), log.Prefix(), log.Flags())},
}
}
Expand Down Expand Up @@ -239,6 +240,9 @@ type Pinger struct {
logger Logger

TTL int

// tclass defines the traffic class (ToS for IPv4) set on outgoing icmp packets
tclass uint8
}

type packet struct {
Expand Down Expand Up @@ -488,6 +492,18 @@ func (p *Pinger) SetDoNotFragment(df bool) {
p.df = df
}

// SetTrafficClass sets the traffic class (type-of-service field for IPv4) field
// value for future outgoing packets.
func (p *Pinger) SetTrafficClass(tc uint8) {
p.tclass = tc
}

// TrafficClass returns the traffic class field (type-of-service field for IPv4)
// value for outgoing packets.
func (p *Pinger) TrafficClass() uint8 {
return p.tclass
}

// Run runs the pinger. This is a blocking function that will exit when it's
// done. If Count or Interval are not specified, it will run continuously until
// it is interrupted.
Expand Down Expand Up @@ -527,6 +543,12 @@ func (p *Pinger) RunWithContext(ctx context.Context) error {
}
}

if p.tclass != 0 {
if err := conn.SetTrafficClass(p.tclass); err != nil {
return fmt.Errorf("error setting traffic class: %v", err)
}
}

conn.SetTTL(p.TTL)
if p.InterfaceName != "" {
iface, err := net.InterfaceByName(p.InterfaceName)
Expand Down
14 changes: 14 additions & 0 deletions ping_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ func TestNewPingerValid(t *testing.T) {
AssertNotEqualStrings(t, "www.google.com", p.IPAddr().String())
AssertTrue(t, isIPv4(p.IPAddr().IP))
AssertFalse(t, p.Privileged())
AssertEquals(t, 192, p.tclass)
// Test that SetPrivileged works
p.SetPrivileged(true)
AssertTrue(t, p.Privileged())
Expand All @@ -253,6 +254,9 @@ func TestNewPingerValid(t *testing.T) {
err = p.SetAddr("ipv6.google.com")
AssertNoError(t, err)
AssertFalse(t, isIPv4(p.IPAddr().IP))
// Test setting traffic class
p.SetTrafficClass(0)
AssertEquals(t, 0, p.tclass)

p = New("localhost")
err = p.Resolve()
Expand Down Expand Up @@ -541,6 +545,14 @@ func AssertEqualStrings(t *testing.T, expected, actual string) {
}
}

func AssertEquals[T comparable](t *testing.T, expected, actual T) {
t.Helper()
if expected != actual {
t.Errorf("Expected %v, got %v, Stack:\n%s",
expected, actual, string(debug.Stack()))
}
}

func AssertNotEqualStrings(t *testing.T, expected, actual string) {
t.Helper()
if expected == actual {
Expand Down Expand Up @@ -666,6 +678,8 @@ func (c testPacketConn) SetMark(m uint) error { return nil }
func (c testPacketConn) SetDoNotFragment() error { return nil }
func (c testPacketConn) SetBroadcastFlag() error { return nil }
func (c testPacketConn) SetIfIndex(ifIndex int) {}
func (c testPacketConn) SetTrafficClass(uint8) error { return nil }

func (c testPacketConn) ReadFrom(b []byte) (n int, ttl int, src net.Addr, err error) {
return 0, 0, testAddr, nil
}
Expand Down

0 comments on commit 543f9b2

Please sign in to comment.