croc/src/tcp/tcp.go

314 lines
6.6 KiB
Go

package tcp
import (
"bytes"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/pkg/errors"
"github.com/schollz/croc/v6/src/comm"
"github.com/schollz/croc/v6/src/models"
log "github.com/schollz/logger"
)
type (
	// server is the relay: it listens on port, authenticates clients against
	// password, and pairs up clients that ask for the same room name.
	server struct {
		port       string
		debugLevel string
		banner     string
		password   string
		rooms      roomMap
	}

	// roomInfo records the connections sharing one room. first is the client
	// that created the room, second the client that joined it; full is set
	// once second has joined. opened is when the room was created.
	roomInfo struct {
		first  *comm.Comm
		second *comm.Comm
		opened time.Time
		full   bool
	}

	// roomMap holds the open rooms keyed by room name. The embedded mutex
	// guards the map.
	roomMap struct {
		rooms map[string]roomInfo
		sync.Mutex
	}
)

// timeToRoomDeletion is the pause between sweeps for stale rooms.
var timeToRoomDeletion = 10 * time.Minute
// Run configures a relay server and starts it. The call blocks until the
// listener fails and returns that error. The optional banner (only the first
// value is used) is sent to clients after they authenticate; when absent the
// server sends "ok".
func Run(debugLevel, port, password string, banner ...string) (err error) {
	s := &server{
		port:       port,
		password:   password,
		debugLevel: debugLevel,
	}
	if len(banner) != 0 {
		s.banner = banner[0]
	}
	return s.start()
}
// start sets the log level and initialises the room table. It then launches
// a background sweeper that deletes rooms opened more than three hours ago,
// and serves connections via run. Any error from run is logged and returned.
// The sweeper is stopped when start returns, so it does not outlive the
// server.
func (s *server) start() (err error) {
	log.SetLevel(s.debugLevel)
	log.Debugf("starting with password '%s'", s.password)
	s.rooms.Lock()
	s.rooms.rooms = make(map[string]roomInfo)
	s.rooms.Unlock()

	// done is closed when start returns so the sweeper goroutine exits
	// instead of leaking once the listener has failed.
	done := make(chan struct{})
	defer close(done)

	// delete old rooms
	go func() {
		for {
			// timeToRoomDeletion is re-read every cycle, as before.
			select {
			case <-done:
				return
			case <-time.After(timeToRoomDeletion):
			}
			var roomsToDelete []string
			s.rooms.Lock()
			for room, info := range s.rooms.rooms {
				if time.Since(info.opened) > 3*time.Hour {
					roomsToDelete = append(roomsToDelete, room)
				}
			}
			s.rooms.Unlock()
			// deleteRoom acquires the lock itself, so it must run after Unlock.
			for _, room := range roomsToDelete {
				s.deleteRoom(room)
			}
		}
	}()

	err = s.run()
	if err != nil {
		log.Error(err)
	}
	return
}
// run listens on s.port and hands each accepted connection to
// clientCommuncation in its own goroutine. It only returns when listening or
// accepting fails. The listener is closed on return.
func (s *server) run() (err error) {
	// Constant format string: the port is an argument, never a format.
	log.Infof("starting TCP server on %s", s.port)
	listener, err := net.Listen("tcp", ":"+s.port)
	if err != nil {
		return errors.Wrap(err, "Error listening on :"+s.port)
	}
	defer listener.Close()
	// spawn a new goroutine whenever a client connects
	for {
		connection, errAccept := listener.Accept()
		if errAccept != nil {
			return errors.Wrap(errAccept, "problem accepting connection")
		}
		log.Debugf("client %s connected", connection.RemoteAddr().String())
		go func(port string, connection net.Conn) {
			errCommunication := s.clientCommuncation(port, comm.New(connection))
			if errCommunication != nil {
				log.Warnf("relay-%s: %s", connection.RemoteAddr().String(), errCommunication.Error())
			}
		}(s.port, connection)
	}
}
// clientCommuncation runs the relay handshake for one client connection.
// It checks the password, then sends the banner (or "ok") followed by "|||"
// and the client's remote address. It then reads the requested room name and
// does one of three things:
//   - room does not exist: create it with c as first and reply "ok";
//   - room already full: reply "room full" and drop c;
//   - otherwise: store c as second, pipe data between the two clients, reply
//     "ok", and delete the room once piping ends.
//
// Until c is stored in a room, nothing else references it, so every early
// exit before that point closes c itself. port is currently unused.
func (s *server) clientCommuncation(port string, c *comm.Comm) (err error) {
	log.Debugf("waiting for password")
	passwordBytes, err := c.Receive()
	if err != nil {
		c.Close()
		return
	}
	if strings.TrimSpace(string(passwordBytes)) != s.password {
		err = fmt.Errorf("bad password")
		// Best effort: tell the client why it is being dropped; the
		// connection is closed either way.
		_ = c.Send([]byte(err.Error()))
		c.Close()
		return
	}
	// send ok to tell client they are connected
	banner := s.banner
	if len(banner) == 0 {
		banner = "ok"
	}
	log.Debugf("sending '%s'", banner)
	err = c.Send([]byte(banner + "|||" + c.Connection().RemoteAddr().String()))
	if err != nil {
		c.Close()
		return
	}
	// wait for client to tell me which room they want
	log.Debug("waiting for answer")
	roomBytes, err := c.Receive()
	if err != nil {
		c.Close()
		return
	}
	room := string(roomBytes)

	s.rooms.Lock()
	info, exists := s.rooms.rooms[room]
	if !exists {
		// create the room; from here on deleteRoom is responsible for closing c
		s.rooms.rooms[room] = roomInfo{
			first:  c,
			opened: time.Now(),
		}
		s.rooms.Unlock()
		// tell the client that they got the room
		err = c.Send([]byte("ok"))
		if err != nil {
			log.Error(err)
			s.deleteRoom(room)
			return
		}
		log.Debugf("room %s has 1", room)
		return nil
	}
	if info.full {
		s.rooms.Unlock()
		// The room belongs to two other clients: only this connection is
		// turned away. Deleting the room here would tear down their transfer.
		err = c.Send([]byte("room full"))
		c.Close()
		if err != nil {
			log.Error(err)
			return
		}
		return nil
	}
	log.Debugf("room %s has 2", room)
	info.second = c
	info.full = true
	s.rooms.rooms[room] = info
	otherConnection := info.first
	s.rooms.Unlock()

	// second connection is the sender, time to staple connections
	var wg sync.WaitGroup
	wg.Add(1)
	// start piping
	go func(com1, com2 *comm.Comm, wg *sync.WaitGroup) {
		log.Debug("starting pipes")
		pipe(com1.Connection(), com2.Connection())
		wg.Done()
		log.Debug("done piping")
	}(otherConnection, c, &wg)
	// tell the sender everything is ready
	err = c.Send([]byte("ok"))
	if err != nil {
		// deleteRoom closes both connections, which also ends the pipe.
		s.deleteRoom(room)
		return
	}
	wg.Wait()
	// delete room
	s.deleteRoom(room)
	return nil
}
// deleteRoom closes every connection held by room and removes the room from
// the table. It is a no-op for an unknown room. It acquires the rooms lock
// itself (sync.Mutex is not reentrant), so callers must not already hold it.
func (s *server) deleteRoom(room string) {
	s.rooms.Lock()
	defer s.rooms.Unlock()
	info, ok := s.rooms.rooms[room]
	if !ok {
		return
	}
	log.Debugf("deleting room: %s", room)
	if info.first != nil {
		info.first.Close()
	}
	if info.second != nil {
		info.second.Close()
	}
	delete(s.rooms.rooms, room)
}
// chanFromConn starts a goroutine that reads conn until a read fails, and
// forwards every chunk it reads on the returned channel. Each chunk is a
// fresh copy, so receivers may keep it.
//
// When a read fails (EOF, or conn being closed), the channel is closed.
// Receivers then get a nil slice, which is the end-of-stream signal pipe
// checks for.
//
// NOTE(review): a data send can still block forever if the receiver has
// stopped reading while the one-slot buffer is full.
func chanFromConn(conn net.Conn) chan []byte {
	c := make(chan []byte, 1)
	go func() {
		// Closing instead of sending a nil sentinel never blocks, so the
		// goroutine can always exit after a read error, even if the receiver
		// has already gone away.
		defer close(c)
		b := make([]byte, models.TCP_BUFFER_SIZE)
		for {
			n, err := conn.Read(b)
			if n > 0 {
				res := make([]byte, n)
				// Copy the buffer so it doesn't get changed while read by the recipient.
				copy(res, b[:n])
				c <- res
			}
			if err != nil {
				log.Debug(err)
				break
			}
		}
		log.Debug("exiting")
	}()
	return c
}
// pipe creates a full-duplex pipe between the two sockets and transfers data
// from one to the other. It returns when either side signals end of stream
// (a nil chunk from chanFromConn) or when a write to either side fails. It
// does not close the connections; the caller is responsible for that.
func pipe(conn1 net.Conn, conn2 net.Conn) {
	chan1 := chanFromConn(conn1)
	chan2 := chanFromConn(conn2)
	for {
		select {
		case b1 := <-chan1:
			if b1 == nil {
				return
			}
			// A failed write means the peer is gone; stop instead of
			// silently dropping data in a loop.
			if _, err := conn2.Write(b1); err != nil {
				log.Debug(err)
				return
			}
		case b2 := <-chan2:
			if b2 == nil {
				return
			}
			if _, err := conn1.Write(b2); err != nil {
				log.Debug(err)
				return
			}
		}
	}
}
// ConnectToTCPServer will initiate a new connection to the relay at address,
// optionally bounded by timelimit[0]. It sends password and joins room.
//
// On success it returns:
//   - the open connection;
//   - the relay's banner;
//   - the address the relay reports for this client (the text after "|||"
//     in the relay's first reply).
//
// If any handshake step fails after the connection was opened, the
// connection is closed before the error is returned.
func ConnectToTCPServer(address, password, room string, timelimit ...time.Duration) (c *comm.Comm, banner string, ipaddr string, err error) {
	if len(timelimit) > 0 {
		c, err = comm.NewConnection(address, timelimit[0])
	} else {
		c, err = comm.NewConnection(address)
	}
	if err != nil {
		return
	}
	// Don't leak the socket when the handshake fails part-way.
	defer func() {
		if err != nil {
			c.Close()
		}
	}()
	log.Debug("sending password")
	err = c.Send([]byte(password))
	if err != nil {
		return
	}
	log.Debug("waiting for first ok")
	data, err := c.Receive()
	if err != nil {
		return
	}
	if !strings.Contains(string(data), "|||") {
		err = fmt.Errorf("bad response: %s", string(data))
		return
	}
	// The Contains check above guarantees at least two parts.
	parts := strings.Split(string(data), "|||")
	banner = parts[0]
	ipaddr = parts[1]
	log.Debug("sending room")
	err = c.Send([]byte(room))
	if err != nil {
		return
	}
	log.Debug("waiting for room confirmation")
	data, err = c.Receive()
	if err != nil {
		return
	}
	if !bytes.Equal(data, []byte("ok")) {
		err = fmt.Errorf("got bad response: %s", data)
		return
	}
	log.Debug("all set")
	return
}