diff --git a/src/cli/cli.go b/src/cli/cli.go index 6695f95..b9daf87 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -79,6 +79,7 @@ func Run() (err error) { HelpName: "croc relay", Action: relay, Flags: []cli.Flag{ + &cli.StringFlag{Name: "host", Usage: "host of the relay"}, &cli.StringFlag{Name: "ports", Value: "9009,9010,9011,9012,9013", Usage: "ports of the relay"}, }, }, @@ -526,6 +527,7 @@ func relay(c *cli.Context) (err error) { if c.Bool("debug") { debugString = "debug" } + host := c.String("host") ports := strings.Split(c.String("ports"), ",") tcpPorts := strings.Join(ports[1:], ",") for i, port := range ports { @@ -533,11 +535,11 @@ func relay(c *cli.Context) (err error) { continue } go func(portStr string) { - err = tcp.Run(debugString, portStr, determinePass(c)) + err = tcp.Run(debugString, host, portStr, determinePass(c)) if err != nil { panic(err) } }(port) } - return tcp.Run(debugString, ports[0], determinePass(c), tcpPorts) + return tcp.Run(debugString, host, ports[0], determinePass(c), tcpPorts) } diff --git a/src/croc/croc.go b/src/croc/croc.go index acd562d..93c6e18 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -296,7 +296,7 @@ func (c *Client) setupLocalRelay() { if c.Options.Debug { debugString = "debug" } - err := tcp.Run(debugString, portStr, c.Options.RelayPassword, strings.Join(c.Options.RelayPorts[1:], ",")) + err := tcp.Run(debugString, "localhost", portStr, c.Options.RelayPassword, strings.Join(c.Options.RelayPorts[1:], ",")) if err != nil { panic(err) } diff --git a/src/croc/croc_test.go b/src/croc/croc_test.go index 4142764..afd53ab 100644 --- a/src/croc/croc_test.go +++ b/src/croc/croc_test.go @@ -15,11 +15,11 @@ import ( func init() { log.SetLevel("trace") - go tcp.Run("debug", "8081", "pass123", "8082,8083,8084,8085") - go tcp.Run("debug", "8082", "pass123") - go tcp.Run("debug", "8083", "pass123") - go tcp.Run("debug", "8084", "pass123") - go tcp.Run("debug", "8085", "pass123") + go tcp.Run("debug", "localhost", "8081", "pass123", "8082,8083,8084,8085") + go tcp.Run("debug", "localhost", "8082", "pass123") + go tcp.Run("debug", "localhost", "8083", "pass123") + go tcp.Run("debug", "localhost", "8084", "pass123") + go tcp.Run("debug", "localhost", "8085", "pass123") time.Sleep(1 * time.Second) } diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index c57baf3..d7b3c3b 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -17,6 +17,7 @@ import ( ) type server struct { + host string port string debugLevel string banner string @@ -42,8 +43,9 @@ const ( ) // Run starts a tcp listener, run async -func Run(debugLevel, port, password string, banner ...string) (err error) { +func Run(debugLevel, host, port, password string, banner ...string) (err error) { s := new(server) + s.host = host s.port = port s.password = password s.debugLevel = debugLevel @@ -87,10 +89,30 @@ func (s *server) start() (err error) { } func (s *server) run() (err error) { - log.Infof("starting TCP server on " + s.port) - server, err := net.Listen("tcp", ":"+s.port) + network := "tcp" + addr := net.JoinHostPort(s.host, s.port) + if s.host != "" { + ip := net.ParseIP(s.host) + if ip == nil { + tcpIP, err := net.ResolveIPAddr("ip", s.host) + if err != nil { + return err + } + ip = tcpIP.IP + } + addr = net.JoinHostPort(ip.String(), s.port) + if s.host != "" { + if ip.To4() != nil { + network = "tcp4" + } else { + network = "tcp6" + } + } + } + log.Infof("starting TCP server on " + addr) + server, err := net.Listen(network, addr) if err != nil { - return fmt.Errorf("error listening on %s: %w", s.port, err) + return fmt.Errorf("error listening on %s: %w", addr, err) } defer server.Close() // spawn a new goroutine whenever a client connects diff --git a/src/tcp/tcp_test.go b/src/tcp/tcp_test.go index ccceba7..165f953 100644 --- a/src/tcp/tcp_test.go +++ b/src/tcp/tcp_test.go @@ -12,7 +12,7 @@ import ( func BenchmarkConnection(b *testing.B) { log.SetLevel("trace") - go Run("debug", "8283", "pass123", "8284") + go Run("debug", "localhost", "8283", "pass123", "8284") time.Sleep(100 * time.Millisecond) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -24,7 +24,7 @@ func BenchmarkConnection(b *testing.B) { func TestTCP(t *testing.T) { log.SetLevel("error") timeToRoomDeletion = 100 * time.Millisecond - go Run("debug", "8281", "pass123", "8282") + go Run("debug", "localhost", "8281", "pass123", "8282") time.Sleep(100 * time.Millisecond) err := PingServer("localhost:8281") assert.Nil(t, err)