From de454bbf5aaf17b532fd43584b3cec6e9c8dd933 Mon Sep 17 00:00:00 2001 From: KallyDev Date: Fri, 1 Oct 2021 12:55:28 +0800 Subject: [PATCH] add host flag for relay --- src/cli/cli.go | 6 ++++-- src/croc/croc.go | 2 +- src/croc/croc_test.go | 10 +++++----- src/tcp/tcp.go | 30 ++++++++++++++++++++++++++---- src/tcp/tcp_test.go | 4 ++-- 5 files changed, 38 insertions(+), 14 deletions(-) diff --git a/src/cli/cli.go b/src/cli/cli.go index b28491e..ae61fe4 100644 --- a/src/cli/cli.go +++ b/src/cli/cli.go @@ -83,6 +83,7 @@ func Run() (err error) { return relay(c) }, Flags: []cli.Flag{ + &cli.StringFlag{Name: "host", Usage: "host of the relay"}, &cli.StringFlag{Name: "ports", Value: "9009,9010,9011,9012,9013", Usage: "ports of the relay"}, }, }, @@ -519,6 +520,7 @@ func relay(c *cli.Context) (err error) { if c.Bool("debug") { debugString = "debug" } + host := c.String("host") ports := strings.Split(c.String("ports"), ",") tcpPorts := strings.Join(ports[1:], ",") for i, port := range ports { @@ -526,11 +528,11 @@ func relay(c *cli.Context) (err error) { continue } go func(portStr string) { - err = tcp.Run(debugString, portStr, determinePass(c)) + err = tcp.Run(debugString, host, portStr, determinePass(c)) if err != nil { panic(err) } }(port) } - return tcp.Run(debugString, ports[0], determinePass(c), tcpPorts) + return tcp.Run(debugString, host, ports[0], determinePass(c), tcpPorts) } diff --git a/src/croc/croc.go b/src/croc/croc.go index 95790ca..2cc8321 100644 --- a/src/croc/croc.go +++ b/src/croc/croc.go @@ -291,7 +291,7 @@ func (c *Client) setupLocalRelay() { if c.Options.Debug { debugString = "debug" } - err := tcp.Run(debugString, portStr, c.Options.RelayPassword, strings.Join(c.Options.RelayPorts[1:], ",")) + err := tcp.Run(debugString, "localhost", portStr, c.Options.RelayPassword, strings.Join(c.Options.RelayPorts[1:], ",")) if err != nil { panic(err) } diff --git a/src/croc/croc_test.go b/src/croc/croc_test.go index 4142764..afd53ab 100644 --- a/src/croc/croc_test.go +++ b/src/croc/croc_test.go @@ -15,11 +15,11 @@ import ( func init() { log.SetLevel("trace") - go tcp.Run("debug", "8081", "pass123", "8082,8083,8084,8085") - go tcp.Run("debug", "8082", "pass123") - go tcp.Run("debug", "8083", "pass123") - go tcp.Run("debug", "8084", "pass123") - go tcp.Run("debug", "8085", "pass123") + go tcp.Run("debug", "localhost", "8081", "pass123", "8082,8083,8084,8085") + go tcp.Run("debug", "localhost", "8082", "pass123") + go tcp.Run("debug", "localhost", "8083", "pass123") + go tcp.Run("debug", "localhost", "8084", "pass123") + go tcp.Run("debug", "localhost", "8085", "pass123") time.Sleep(1 * time.Second) } diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index 96b4145..c17d710 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -17,6 +17,7 @@ import ( ) type server struct { + host string port string debugLevel string banner string @@ -40,8 +41,9 @@ var timeToRoomDeletion = 10 * time.Minute var pingRoom = "pinglkasjdlfjsaldjf" // Run starts a tcp listener, run async -func Run(debugLevel, port, password string, banner ...string) (err error) { +func Run(debugLevel, host, port, password string, banner ...string) (err error) { s := new(server) + s.host = host s.port = port s.password = password s.debugLevel = debugLevel @@ -85,10 +87,30 @@ func (s *server) start() (err error) { } func (s *server) run() (err error) { - log.Infof("starting TCP server on " + s.port) - server, err := net.Listen("tcp", ":"+s.port) + network := "tcp" + addr := net.JoinHostPort(s.host, s.port) + if s.host != "" { + ip := net.ParseIP(s.host) + if ip == nil { + tcpIP, err := net.ResolveIPAddr("ip", s.host) + if err != nil { + return err + } + ip = tcpIP.IP + } + addr = net.JoinHostPort(ip.String(), s.port) + if s.host != "" { + if ip.To4() != nil { + network = "tcp4" + } else { + network = "tcp6" + } + } + } + log.Infof("starting TCP server on " + addr) + server, err := net.Listen(network, addr) if err != nil { - return fmt.Errorf("error listening on %s: %w", s.port, err) + return fmt.Errorf("error listening on %s: %w", addr, err) } defer server.Close() // spawn a new goroutine whenever a client connects diff --git a/src/tcp/tcp_test.go b/src/tcp/tcp_test.go index ccceba7..165f953 100644 --- a/src/tcp/tcp_test.go +++ b/src/tcp/tcp_test.go @@ -12,7 +12,7 @@ import ( func BenchmarkConnection(b *testing.B) { log.SetLevel("trace") - go Run("debug", "8283", "pass123", "8284") + go Run("debug", "localhost", "8283", "pass123", "8284") time.Sleep(100 * time.Millisecond) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -24,7 +24,7 @@ func BenchmarkConnection(b *testing.B) { func TestTCP(t *testing.T) { log.SetLevel("error") timeToRoomDeletion = 100 * time.Millisecond - go Run("debug", "8281", "pass123", "8282") + go Run("debug", "localhost", "8281", "pass123", "8282") time.Sleep(100 * time.Millisecond) err := PingServer("localhost:8281") assert.Nil(t, err)