diff --git a/src/comm/comm.go b/src/comm/comm.go index aea7486..84df3d4 100644 --- a/src/comm/comm.go +++ b/src/comm/comm.go @@ -1,6 +1,7 @@ package comm import ( + "bufio" "bytes" "fmt" "net" @@ -12,32 +13,39 @@ import ( // Comm is some basic TCP communication type Comm struct { connection net.Conn + writer *bufio.Writer } // New returns a new comm -func New(c net.Conn) Comm { - c.SetReadDeadline(time.Now().Add(3 * time.Hour)) - c.SetDeadline(time.Now().Add(3 * time.Hour)) - c.SetWriteDeadline(time.Now().Add(3 * time.Hour)) - return Comm{c} +func New(n net.Conn) *Comm { + c := new(Comm) + c.connection = n + c.connection.SetReadDeadline(time.Now().Add(3 * time.Hour)) + c.connection.SetDeadline(time.Now().Add(3 * time.Hour)) + c.connection.SetWriteDeadline(time.Now().Add(3 * time.Hour)) + c.writer = bufio.NewWriter(n) + return c } // Connection returns the net.Conn connection -func (c Comm) Connection() net.Conn { +func (c *Comm) Connection() net.Conn { return c.connection } // Close closes the connection -func (c Comm) Close() { +func (c *Comm) Close() { c.connection.Close() } -func (c Comm) Write(b []byte) (int, error) { - c.connection.Write([]byte(fmt.Sprintf("%0.6d", len(b)))) - n, err := c.connection.Write(b) +func (c *Comm) Write(b []byte) (int, error) { + c.writer.Write([]byte(fmt.Sprintf("%0.6d", len(b)))) + n, err := c.writer.Write(b) if n != len(b) { err = fmt.Errorf("wanted to write %d but wrote %d", n, len(b)) } + if err == nil { + err = c.writer.Flush() + } // log.Printf("wanted to write %d but wrote %d", n, len(b)) return n, err } diff --git a/src/recipient/recipient.go b/src/recipient/recipient.go index 5d283b2..8e70fec 100644 --- a/src/recipient/recipient.go +++ b/src/recipient/recipient.go @@ -50,7 +50,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo var transferTime time.Duration var hash256 []byte var otherIP string - var tcpConnections []comm.Comm + var tcpConnections []*comm.Comm dataChan := make(chan []byte, 1024*1024) useWebsockets := true @@ -176,7 +176,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo // connect to TCP to receive file if !useWebsockets { log.Debugf("connecting to server") - tcpConnections = make([]comm.Comm, len(tcpPorts)) + tcpConnections = make([]*comm.Comm, len(tcpPorts)) for i, tcpPort := range tcpPorts { tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) if err != nil { @@ -300,7 +300,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo var wg sync.WaitGroup wg.Add(len(tcpConnections)) for i := range tcpConnections { - go func(wg *sync.WaitGroup, tcpConnection comm.Comm) { + go func(wg *sync.WaitGroup, tcpConnection *comm.Comm) { defer wg.Done() for { // read from TCP connection @@ -405,7 +405,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo } } -func connectToTCPServer(room string, address string) (com comm.Comm, err error) { +func connectToTCPServer(room string, address string) (com *comm.Comm, err error) { log.Debugf("recipient connecting to %s", address) connection, err := net.Dial("tcp", address) if err != nil { diff --git a/src/sender/sender.go b/src/sender/sender.go index caed83d..0975118 100644 --- a/src/sender/sender.go +++ b/src/sender/sender.go @@ -51,7 +51,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, var fileHash []byte var otherIP string var startTransfer time.Time - var tcpConnections []comm.Comm + var tcpConnections []*comm.Comm type DataChan struct { b []byte @@ -302,7 +302,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, // connect to TCP to receive file if !useWebsockets { log.Debugf("connecting to server") - tcpConnections = make([]comm.Comm, len(tcpPorts)) + tcpConnections = make([]*comm.Comm, len(tcpPorts)) for i, tcpPort := range tcpPorts { log.Debug(tcpPort) tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) @@ -346,7 +346,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, var wg sync.WaitGroup wg.Add(len(tcpConnections)) for i := range tcpConnections { - go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection comm.Comm) { + go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection *comm.Comm) { defer wg.Done() for data := range dataChan { if data.err != nil { @@ -407,7 +407,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool, } } -func connectToTCPServer(room string, address string) (com comm.Comm, err error) { +func connectToTCPServer(room string, address string) (com *comm.Comm, err error) { connection, err := net.Dial("tcp", address) if err != nil { return diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index 4177f77..d6feeb7 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -13,7 +13,7 @@ import ( ) type roomInfo struct { - receiver comm.Comm + receiver *comm.Comm opened time.Time } @@ -62,7 +62,7 @@ func run(port string) (err error) { } } -func clientCommuncation(port string, c comm.Comm) (err error) { +func clientCommuncation(port string, c *comm.Comm) (err error) { // send ok to tell client they are connected err = c.Send("ok") if err != nil { @@ -98,7 +98,7 @@ func clientCommuncation(port string, c comm.Comm) (err error) { wg.Add(1) // start piping - go func(com1, com2 comm.Comm, wg *sync.WaitGroup) { + go func(com1, com2 *comm.Comm, wg *sync.WaitGroup) { log.Debug("starting pipes") pipe(com1.Connection(), com2.Connection()) wg.Done()