implement readers and writers all around

This commit is contained in:
Zack Scholl 2018-09-26 07:39:45 -07:00
parent ea548f290c
commit 085dd4e4c3
4 changed files with 29 additions and 21 deletions

View File

@ -1,6 +1,7 @@
package comm package comm
import ( import (
"bufio"
"bytes" "bytes"
"fmt" "fmt"
"net" "net"
@ -12,32 +13,39 @@ import (
// Comm is some basic TCP communication // Comm is some basic TCP communication
type Comm struct { type Comm struct {
connection net.Conn connection net.Conn
writer *bufio.Writer
} }
// New returns a new comm // New returns a new comm
func New(c net.Conn) Comm { func New(n net.Conn) *Comm {
c.SetReadDeadline(time.Now().Add(3 * time.Hour)) c := new(Comm)
c.SetDeadline(time.Now().Add(3 * time.Hour)) c.connection = n
c.SetWriteDeadline(time.Now().Add(3 * time.Hour)) c.connection.SetReadDeadline(time.Now().Add(3 * time.Hour))
return Comm{c} c.connection.SetDeadline(time.Now().Add(3 * time.Hour))
c.connection.SetWriteDeadline(time.Now().Add(3 * time.Hour))
c.writer = bufio.NewWriter(n)
return c
} }
// Connection returns the net.Conn connection // Connection returns the net.Conn connection
func (c Comm) Connection() net.Conn { func (c *Comm) Connection() net.Conn {
return c.connection return c.connection
} }
// Close closes the connection // Close closes the connection
func (c Comm) Close() { func (c *Comm) Close() {
c.connection.Close() c.connection.Close()
} }
func (c Comm) Write(b []byte) (int, error) { func (c *Comm) Write(b []byte) (int, error) {
c.connection.Write([]byte(fmt.Sprintf("%0.6d", len(b)))) c.writer.Write([]byte(fmt.Sprintf("%0.6d", len(b))))
n, err := c.connection.Write(b) n, err := c.writer.Write(b)
if n != len(b) { if n != len(b) {
err = fmt.Errorf("wanted to write %d but wrote %d", n, len(b)) err = fmt.Errorf("wanted to write %d but wrote %d", n, len(b))
} }
if err == nil {
err = c.writer.Flush()
}
// log.Printf("wanted to write %d but wrote %d", n, len(b)) // log.Printf("wanted to write %d but wrote %d", n, len(b))
return n, err return n, err
} }

View File

@ -50,7 +50,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
var transferTime time.Duration var transferTime time.Duration
var hash256 []byte var hash256 []byte
var otherIP string var otherIP string
var tcpConnections []comm.Comm var tcpConnections []*comm.Comm
dataChan := make(chan []byte, 1024*1024) dataChan := make(chan []byte, 1024*1024)
useWebsockets := true useWebsockets := true
@ -176,7 +176,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
// connect to TCP to receive file // connect to TCP to receive file
if !useWebsockets { if !useWebsockets {
log.Debugf("connecting to server") log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts)) tcpConnections = make([]*comm.Comm, len(tcpPorts))
for i, tcpPort := range tcpPorts { for i, tcpPort := range tcpPorts {
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
if err != nil { if err != nil {
@ -300,7 +300,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(tcpConnections)) wg.Add(len(tcpConnections))
for i := range tcpConnections { for i := range tcpConnections {
go func(wg *sync.WaitGroup, tcpConnection comm.Comm) { go func(wg *sync.WaitGroup, tcpConnection *comm.Comm) {
defer wg.Done() defer wg.Done()
for { for {
// read from TCP connection // read from TCP connection
@ -405,7 +405,7 @@ func receive(forceSend int, serverAddress string, tcpPorts []string, isLocal boo
} }
} }
func connectToTCPServer(room string, address string) (com comm.Comm, err error) { func connectToTCPServer(room string, address string) (com *comm.Comm, err error) {
log.Debugf("recipient connecting to %s", address) log.Debugf("recipient connecting to %s", address)
connection, err := net.Dial("tcp", address) connection, err := net.Dial("tcp", address)
if err != nil { if err != nil {

View File

@ -51,7 +51,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
var fileHash []byte var fileHash []byte
var otherIP string var otherIP string
var startTransfer time.Time var startTransfer time.Time
var tcpConnections []comm.Comm var tcpConnections []*comm.Comm
type DataChan struct { type DataChan struct {
b []byte b []byte
@ -302,7 +302,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
// connect to TCP to receive file // connect to TCP to receive file
if !useWebsockets { if !useWebsockets {
log.Debugf("connecting to server") log.Debugf("connecting to server")
tcpConnections = make([]comm.Comm, len(tcpPorts)) tcpConnections = make([]*comm.Comm, len(tcpPorts))
for i, tcpPort := range tcpPorts { for i, tcpPort := range tcpPorts {
log.Debug(tcpPort) log.Debug(tcpPort)
tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort) tcpConnections[i], err = connectToTCPServer(utils.SHA256(fmt.Sprintf("%d%x", i, sessionKey)), serverAddress+":"+tcpPort)
@ -346,7 +346,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(len(tcpConnections)) wg.Add(len(tcpConnections))
for i := range tcpConnections { for i := range tcpConnections {
go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection comm.Comm) { go func(i int, wg *sync.WaitGroup, dataChan <-chan DataChan, tcpConnection *comm.Comm) {
defer wg.Done() defer wg.Done()
for data := range dataChan { for data := range dataChan {
if data.err != nil { if data.err != nil {
@ -407,7 +407,7 @@ func send(forceSend int, serverAddress string, tcpPorts []string, isLocal bool,
} }
} }
func connectToTCPServer(room string, address string) (com comm.Comm, err error) { func connectToTCPServer(room string, address string) (com *comm.Comm, err error) {
connection, err := net.Dial("tcp", address) connection, err := net.Dial("tcp", address)
if err != nil { if err != nil {
return return

View File

@ -13,7 +13,7 @@ import (
) )
type roomInfo struct { type roomInfo struct {
receiver comm.Comm receiver *comm.Comm
opened time.Time opened time.Time
} }
@ -62,7 +62,7 @@ func run(port string) (err error) {
} }
} }
func clientCommuncation(port string, c comm.Comm) (err error) { func clientCommuncation(port string, c *comm.Comm) (err error) {
// send ok to tell client they are connected // send ok to tell client they are connected
err = c.Send("ok") err = c.Send("ok")
if err != nil { if err != nil {
@ -98,7 +98,7 @@ func clientCommuncation(port string, c comm.Comm) (err error) {
wg.Add(1) wg.Add(1)
// start piping // start piping
go func(com1, com2 comm.Comm, wg *sync.WaitGroup) { go func(com1, com2 *comm.Comm, wg *sync.WaitGroup) {
log.Debug("starting pipes") log.Debug("starting pipes")
pipe(com1.Connection(), com2.Connection()) pipe(com1.Connection(), com2.Connection())
wg.Done() wg.Done()