Merge pull request #419 from kallydev/patch

Add host flag for relay
This commit is contained in:
Zack 2021-10-02 10:35:00 -07:00 committed by GitHub
commit d77c83ce09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 14 deletions

View File

@ -79,6 +79,7 @@ func Run() (err error) {
HelpName: "croc relay", HelpName: "croc relay",
Action: relay, Action: relay,
Flags: []cli.Flag{ Flags: []cli.Flag{
&cli.StringFlag{Name: "host", Usage: "host of the relay"},
&cli.StringFlag{Name: "ports", Value: "9009,9010,9011,9012,9013", Usage: "ports of the relay"}, &cli.StringFlag{Name: "ports", Value: "9009,9010,9011,9012,9013", Usage: "ports of the relay"},
}, },
}, },
@ -526,6 +527,7 @@ func relay(c *cli.Context) (err error) {
if c.Bool("debug") { if c.Bool("debug") {
debugString = "debug" debugString = "debug"
} }
host := c.String("host")
ports := strings.Split(c.String("ports"), ",") ports := strings.Split(c.String("ports"), ",")
tcpPorts := strings.Join(ports[1:], ",") tcpPorts := strings.Join(ports[1:], ",")
for i, port := range ports { for i, port := range ports {
@ -533,11 +535,11 @@ func relay(c *cli.Context) (err error) {
continue continue
} }
go func(portStr string) { go func(portStr string) {
err = tcp.Run(debugString, portStr, determinePass(c)) err = tcp.Run(debugString, host, portStr, determinePass(c))
if err != nil { if err != nil {
panic(err) panic(err)
} }
}(port) }(port)
} }
return tcp.Run(debugString, ports[0], determinePass(c), tcpPorts) return tcp.Run(debugString, host, ports[0], determinePass(c), tcpPorts)
} }

View File

@ -296,7 +296,7 @@ func (c *Client) setupLocalRelay() {
if c.Options.Debug { if c.Options.Debug {
debugString = "debug" debugString = "debug"
} }
err := tcp.Run(debugString, portStr, c.Options.RelayPassword, strings.Join(c.Options.RelayPorts[1:], ",")) err := tcp.Run(debugString, "localhost", portStr, c.Options.RelayPassword, strings.Join(c.Options.RelayPorts[1:], ","))
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -15,11 +15,11 @@ import (
func init() { func init() {
log.SetLevel("trace") log.SetLevel("trace")
go tcp.Run("debug", "8081", "pass123", "8082,8083,8084,8085") go tcp.Run("debug", "localhost", "8081", "pass123", "8082,8083,8084,8085")
go tcp.Run("debug", "8082", "pass123") go tcp.Run("debug", "localhost", "8082", "pass123")
go tcp.Run("debug", "8083", "pass123") go tcp.Run("debug", "localhost", "8083", "pass123")
go tcp.Run("debug", "8084", "pass123") go tcp.Run("debug", "localhost", "8084", "pass123")
go tcp.Run("debug", "8085", "pass123") go tcp.Run("debug", "localhost", "8085", "pass123")
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }

View File

@ -17,6 +17,7 @@ import (
) )
type server struct { type server struct {
host string
port string port string
debugLevel string debugLevel string
banner string banner string
@ -42,8 +43,9 @@ const (
) )
// Run starts a tcp listener, run async // Run starts a tcp listener, run async
func Run(debugLevel, port, password string, banner ...string) (err error) { func Run(debugLevel, host, port, password string, banner ...string) (err error) {
s := new(server) s := new(server)
s.host = host
s.port = port s.port = port
s.password = password s.password = password
s.debugLevel = debugLevel s.debugLevel = debugLevel
@ -87,10 +89,30 @@ func (s *server) start() (err error) {
} }
func (s *server) run() (err error) { func (s *server) run() (err error) {
log.Infof("starting TCP server on " + s.port) network := "tcp"
server, err := net.Listen("tcp", ":"+s.port) addr := net.JoinHostPort(s.host, s.port)
if s.host != "" {
ip := net.ParseIP(s.host)
if ip == nil {
tcpIP, err := net.ResolveIPAddr("ip", s.host)
if err != nil {
return err
}
ip = tcpIP.IP
}
addr = net.JoinHostPort(ip.String(), s.port)
if s.host != "" {
if ip.To4() != nil {
network = "tcp4"
} else {
network = "tcp6"
}
}
}
log.Infof("starting TCP server on " + addr)
server, err := net.Listen(network, addr)
if err != nil { if err != nil {
return fmt.Errorf("error listening on %s: %w", s.port, err) return fmt.Errorf("error listening on %s: %w", addr, err)
} }
defer server.Close() defer server.Close()
// spawn a new goroutine whenever a client connects // spawn a new goroutine whenever a client connects

View File

@ -12,7 +12,7 @@ import (
func BenchmarkConnection(b *testing.B) { func BenchmarkConnection(b *testing.B) {
log.SetLevel("trace") log.SetLevel("trace")
go Run("debug", "8283", "pass123", "8284") go Run("debug", "localhost", "8283", "pass123", "8284")
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -24,7 +24,7 @@ func BenchmarkConnection(b *testing.B) {
func TestTCP(t *testing.T) { func TestTCP(t *testing.T) {
log.SetLevel("error") log.SetLevel("error")
timeToRoomDeletion = 100 * time.Millisecond timeToRoomDeletion = 100 * time.Millisecond
go Run("debug", "8281", "pass123", "8282") go Run("debug", "localhost", "8281", "pass123", "8282")
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
err := PingServer("localhost:8281") err := PingServer("localhost:8281")
assert.Nil(t, err) assert.Nil(t, err)