diff --git a/connect.go b/connect.go index 048785b..eb45317 100644 --- a/connect.go +++ b/connect.go @@ -6,6 +6,7 @@ import ( "math" "net" "os" + "path" "strconv" "strings" "sync" @@ -55,11 +56,11 @@ func runClient(connectionType string, codePhrase string) { time.Sleep(100 * time.Millisecond) // Write data from file logger.Debug("send file") - sendFile(id, connection) + sendFile(id, connection, codePhrase) } else { // this is a receiver // receive file logger.Debug("receive file") - fileNameToReceive = receiveFile(id, connection) + fileNameToReceive = receiveFile(id, connection, codePhrase) } }(id) @@ -96,7 +97,7 @@ func catFile(fileNameToReceive string) { fmt.Println("\n\n\nDownloaded " + fileNameToReceive + "!") } -func receiveFile(id int, connection net.Conn) string { +func receiveFile(id int, connection net.Conn, codePhrase string) string { logger := log.WithFields(log.Fields{ "function": "receiveFile #" + strconv.Itoa(id), }) @@ -142,19 +143,20 @@ func receiveFile(id int, connection net.Conn) string { return fileNameToReceive } -func sendFile(id int, connection net.Conn) { +func sendFile(id int, connection net.Conn, codePhrase string) { logger := log.WithFields(log.Fields{ "function": "sendFile #" + strconv.Itoa(id), }) defer connection.Close() - //Open the file that needs to be send to the client - file, err := os.Open(fileName) + + // Open the file that needs to be send to the client + file, err := os.Open(fileName + ".encrypted") if err != nil { fmt.Println(err) return } defer file.Close() - //Get the filename and filesize + // Get the filename and filesize fileInfo, err := file.Stat() if err != nil { fmt.Println(err) @@ -170,13 +172,13 @@ func sendFile(id int, connection net.Conn) { } fileSize := fillString(strconv.FormatInt(int64(bytesPerConnection), 10), 10) - fileNameToSend := fillString(fileInfo.Name(), 64) + fileNameToSend := fillString(path.Base(fileName), 64) if id == 0 || id == numberConnections-1 { logger.Debugf("numChunks: %v", numChunks) logger.Debugf("chunksPerWorker: %v", chunksPerWorker) logger.Debugf("bytesPerConnection: %v", bytesPerConnection) - logger.Debugf("fileNameToSend: %v", fileInfo.Name()) + logger.Debugf("fileNameToSend: %v", path.Base(fileName)) } logger.Debugf("sending %s", fileSize) diff --git a/main.go b/main.go index bfcc6a1..e31a604 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "bufio" "flag" "fmt" + "io/ioutil" "os" "strings" @@ -36,11 +37,6 @@ func main() { } if len(fileName) > 0 { - _, err := os.Open(fileName) - if err != nil { - log.Fatal(err) - return - } connectionTypeFlag = "s" // sender } else { connectionTypeFlag = "r" //receiver @@ -58,6 +54,21 @@ func main() { } } + if connectionTypeFlag == "s" { + // encrypt the file + fdata, err := ioutil.ReadFile(fileName) + if err != nil { + log.Fatal(err) + return + } + encrypted, err := Encrypt(fdata, codePhraseFlag) + if err != nil { + log.Fatal(err) + return + } + ioutil.WriteFile(fileName+".encrypted", encrypted, 0644) + } + log.SetFormatter(&log.TextFormatter{}) if debugFlag { log.SetLevel(log.DebugLevel)