package tcp import ( "bytes" "fmt" "net" "strings" "sync" "time" "github.com/pkg/errors" "github.com/schollz/croc/v6/src/comm" "github.com/schollz/croc/v6/src/models" log "github.com/schollz/logger" ) type server struct { port string debugLevel string banner string password string rooms roomMap } type roomInfo struct { first *comm.Comm second *comm.Comm opened time.Time full bool } type roomMap struct { rooms map[string]roomInfo sync.Mutex } var timeToRoomDeletion = 10 * time.Minute // Run starts a tcp listener, run async func Run(debugLevel, port, password string, banner ...string) (err error) { s := new(server) s.port = port s.password = password s.debugLevel = debugLevel if len(banner) > 0 { s.banner = banner[0] } return s.start() } func (s *server) start() (err error) { log.SetLevel(s.debugLevel) log.Debugf("starting with password '%s'", s.password) s.rooms.Lock() s.rooms.rooms = make(map[string]roomInfo) s.rooms.Unlock() // delete old rooms go func() { for { time.Sleep(timeToRoomDeletion) roomsToDelete := []string{} s.rooms.Lock() for room := range s.rooms.rooms { if time.Since(s.rooms.rooms[room].opened) > 3*time.Hour { roomsToDelete = append(roomsToDelete, room) } } s.rooms.Unlock() for _, room := range roomsToDelete { s.deleteRoom(room) } } }() err = s.run() if err != nil { log.Error(err) } return } func (s *server) run() (err error) { log.Infof("starting TCP server on " + s.port) server, err := net.Listen("tcp", ":"+s.port) if err != nil { return errors.Wrap(err, "Error listening on :"+s.port) } defer server.Close() // spawn a new goroutine whenever a client connects for { connection, err := server.Accept() if err != nil { return errors.Wrap(err, "problem accepting connection") } log.Debugf("client %s connected", connection.RemoteAddr().String()) go func(port string, connection net.Conn) { errCommunication := s.clientCommuncation(port, comm.New(connection)) if errCommunication != nil { log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error()) } }(s.port, connection) } } func (s *server) clientCommuncation(port string, c *comm.Comm) (err error) { log.Debugf("waiting for password") passwordBytes, err := c.Receive() if err != nil { return } if strings.TrimSpace(string(passwordBytes)) != s.password { err = fmt.Errorf("bad password") c.Send([]byte(err.Error())) return } // send ok to tell client they are connected banner := s.banner if len(banner) == 0 { banner = "ok" } log.Debugf("sending '%s'", banner) err = c.Send([]byte(banner + "|||" + c.Connection().RemoteAddr().String())) if err != nil { return } // wait for client to tell me which room they want log.Debug("waiting for answer") roomBytes, err := c.Receive() if err != nil { return } room := string(roomBytes) s.rooms.Lock() // create the room if it is new if _, ok := s.rooms.rooms[room]; !ok { s.rooms.rooms[room] = roomInfo{ first: c, opened: time.Now(), } s.rooms.Unlock() // tell the client that they got the room err = c.Send([]byte("ok")) if err != nil { log.Error(err) s.deleteRoom(room) return } log.Debugf("room %s has 1", room) return nil } if s.rooms.rooms[room].full { s.rooms.Unlock() err = c.Send([]byte("room full")) if err != nil { log.Error(err) s.deleteRoom(room) return } return nil } log.Debugf("room %s has 2", room) s.rooms.rooms[room] = roomInfo{ first: s.rooms.rooms[room].first, second: c, opened: s.rooms.rooms[room].opened, full: true, } otherConnection := s.rooms.rooms[room].first s.rooms.Unlock() // second connection is the sender, time to staple connections var wg sync.WaitGroup wg.Add(1) // start piping go func(com1, com2 *comm.Comm, wg *sync.WaitGroup) { log.Debug("starting pipes") pipe(com1.Connection(), com2.Connection()) wg.Done() log.Debug("done piping") }(otherConnection, c, &wg) // tell the sender everything is ready err = c.Send([]byte("ok")) if err != nil { s.deleteRoom(room) return } wg.Wait() // delete room s.deleteRoom(room) return nil } func (s *server) deleteRoom(room string) { s.rooms.Lock() defer s.rooms.Unlock() if _, ok := s.rooms.rooms[room]; !ok { return } log.Debugf("deleting room: %s", room) if s.rooms.rooms[room].first != nil { s.rooms.rooms[room].first.Close() } if s.rooms.rooms[room].second != nil { s.rooms.rooms[room].second.Close() } s.rooms.rooms[room] = roomInfo{first: nil, second: nil} delete(s.rooms.rooms, room) } // chanFromConn creates a channel from a Conn object, and sends everything it // Read()s from the socket to the channel. func chanFromConn(conn net.Conn) chan []byte { c := make(chan []byte, 1) go func() { b := make([]byte, models.TCP_BUFFER_SIZE) for { n, err := conn.Read(b) if n > 0 { res := make([]byte, n) // Copy the buffer so it doesn't get changed while read by the recipient. copy(res, b[:n]) c <- res } if err != nil { log.Debug(err) c <- nil break } } log.Debug("exiting") }() return c } // pipe creates a full-duplex pipe between the two sockets and // transfers data from one to the other. func pipe(conn1 net.Conn, conn2 net.Conn) { chan1 := chanFromConn(conn1) chan2 := chanFromConn(conn2) for { select { case b1 := <-chan1: if b1 == nil { return } conn2.Write(b1) case b2 := <-chan2: if b2 == nil { return } conn1.Write(b2) } } } // ConnectToTCPServer will initiate a new connection // to the specified address, room with optional time limit func ConnectToTCPServer(address, password, room string, timelimit ...time.Duration) (c *comm.Comm, banner string, ipaddr string, err error) { if len(timelimit) > 0 { c, err = comm.NewConnection(address, timelimit[0]) } else { c, err = comm.NewConnection(address) } if err != nil { return } log.Debug("sending password") err = c.Send([]byte(password)) if err != nil { return } log.Debug("waiting for first ok") data, err := c.Receive() if err != nil { return } if !strings.Contains(string(data), "|||") { err = fmt.Errorf("bad response: %s", string(data)) return } banner = strings.Split(string(data), "|||")[0] ipaddr = strings.Split(string(data), "|||")[1] log.Debug("sending room") err = c.Send([]byte(room)) if err != nil { return } log.Debug("waiting for room confirmation") data, err = c.Receive() if err != nil { return } if !bytes.Equal(data, []byte("ok")) { err = fmt.Errorf("got bad response: %s", data) return } log.Debug("all set") return }