diff --git a/connect.go b/connect.go index 4804c17..765555d 100644 --- a/connect.go +++ b/connect.go @@ -3,6 +3,7 @@ package main import ( "fmt" "io" + "io/ioutil" "math" "net" "os" @@ -29,7 +30,7 @@ func runClient(connectionType string, codePhrase string) { uiprogress.Start() bars = make([]*uiprogress.Bar, numberConnections) - fileNameToReceive := "" + var iv, salt, fileNameToReceive string for id := 0; id < numberConnections; id++ { go func(id int) { defer wg.Done() @@ -60,7 +61,7 @@ func runClient(connectionType string, codePhrase string) { } else { // this is a receiver // receive file logger.Debug("receive file") - fileNameToReceive = receiveFile(id, connection, codePhrase) + fileNameToReceive, iv, salt = receiveFile(id, connection, codePhrase) } }(id) @@ -69,13 +70,32 @@ func runClient(connectionType string, codePhrase string) { if connectionType == "r" { catFile(fileNameToReceive) + encrypted, err := ioutil.ReadFile(fileNameToReceive + ".encrypted") + if err != nil { + log.Error(err) + return + } + fmt.Println("\n\ndecrypting...") + decrypted, err := Decrypt(encrypted, codePhrase, salt, iv) + if err != nil { + log.Error(err) + return + } + ioutil.WriteFile(fileNameToReceive, decrypted, 0644) + os.Remove(fileNameToReceive + ".encrypted") + fmt.Println("\nDownloaded " + fileNameToReceive + "!") + } else { + log.Info("cleaning up") + os.Remove(fileName + ".encrypted") + os.Remove(fileName + ".iv") + os.Remove(fileName + ".salt") } } func catFile(fileNameToReceive string) { // cat the file os.Remove(fileNameToReceive) - finished, err := os.Create(fileNameToReceive) + finished, err := os.Create(fileNameToReceive + ".encrypted") defer finished.Close() if err != nil { log.Fatal(err) @@ -94,24 +114,35 @@ func catFile(fileNameToReceive string) { os.Remove(fileNameToReceive + "." + strconv.Itoa(id)) } - fmt.Println("\n\n\nDownloaded " + fileNameToReceive + "!") } -func receiveFile(id int, connection net.Conn, codePhrase string) string { +func receiveFile(id int, connection net.Conn, codePhrase string) (fileNameToReceive string, iv string, salt string) { logger := log.WithFields(log.Fields{ "function": "receiveFile #" + strconv.Itoa(id), }) - bufferFileName := make([]byte, 64) - bufferFileSize := make([]byte, 10) logger.Debug("waiting for file size") + + bufferFileSize := make([]byte, 10) connection.Read(bufferFileSize) fileSize, _ := strconv.ParseInt(strings.Trim(string(bufferFileSize), ":"), 10, 64) logger.Debugf("filesize: %d", fileSize) + bufferFileName := make([]byte, 64) connection.Read(bufferFileName) - fileNameToReceive := strings.Trim(string(bufferFileName), ":") + fileNameToReceive = strings.Trim(string(bufferFileName), ":") logger.Debugf("fileName: %v", fileNameToReceive) + + ivHex := make([]byte, BUFFERSIZE) + connection.Read(ivHex) + iv = strings.Trim(string(ivHex), ":") + logger.Debugf("iv: %v", iv) + + saltHex := make([]byte, BUFFERSIZE) + connection.Read(saltHex) + salt = strings.Trim(string(saltHex), ":") + logger.Debugf("salt: %v", salt) + os.Remove(fileNameToReceive + "." + strconv.Itoa(id)) newFile, err := os.Create(fileNameToReceive + "." + strconv.Itoa(id)) if err != nil { @@ -140,7 +171,7 @@ func receiveFile(id int, connection net.Conn, codePhrase string) string { receivedBytes += BUFFERSIZE } logger.Debug("received file") - return fileNameToReceive + return } func sendFile(id int, connection net.Conn, codePhrase string) { @@ -185,6 +216,25 @@ func sendFile(id int, connection net.Conn, codePhrase string) { connection.Write([]byte(fileSize)) logger.Debugf("sending fileNameToSend: %s", fileNameToSend) connection.Write([]byte(fileNameToSend)) + + // send iv + iv, err := ioutil.ReadFile(fileName + ".iv") + if err != nil { + log.Error(err) + return + } + logger.Debugf("sending iv: %s", iv) + connection.Write([]byte(fillString(string(iv), BUFFERSIZE))) + + // send salt + salt, err := ioutil.ReadFile(fileName + ".salt") + if err != nil { + log.Error(err) + return + } + logger.Debugf("sending salt: %s", salt) + connection.Write([]byte(fillString(string(salt), BUFFERSIZE))) + sendBuffer := make([]byte, BUFFERSIZE) chunkI := 0 diff --git a/crypto.go b/crypto.go index 3afa7af..658d791 100644 --- a/crypto.go +++ b/crypto.go @@ -1,7 +1,6 @@ package main import ( - "bytes" "crypto/aes" "crypto/cipher" "crypto/rand" @@ -29,7 +28,7 @@ func GetRandomName() string { return strings.Join(result, "-") } -func Encrypt(plaintext []byte, passphrase string) (ciphertext []byte, err error) { +func Encrypt(plaintext []byte, passphrase string) ([]byte, string, string) { key, salt := deriveKey(passphrase, nil) iv := make([]byte, 12) // http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf @@ -38,19 +37,16 @@ func Encrypt(plaintext []byte, passphrase string) (ciphertext []byte, err error) b, _ := aes.NewCipher(key) aesgcm, _ := cipher.NewGCM(b) data := aesgcm.Seal(nil, iv, plaintext, nil) - ciphertext = []byte(hex.EncodeToString(salt) + "-" + hex.EncodeToString(iv) + "-" + hex.EncodeToString(data)) - return + return data, hex.EncodeToString(salt), hex.EncodeToString(iv) } -func Decrypt(ciphertext []byte, passphrase string) (plaintext []byte, err error) { - arr := bytes.Split(ciphertext, []byte("-")) - salt, _ := hex.DecodeString(string(arr[0])) - iv, _ := hex.DecodeString(string(arr[1])) - data, _ := hex.DecodeString(string(arr[2])) - key, _ := deriveKey(passphrase, salt) +func Decrypt(data []byte, passphrase string, salt string, iv string) (plaintext []byte, err error) { + saltBytes, _ := hex.DecodeString(salt) + ivBytes, _ := hex.DecodeString(iv) + key, _ := deriveKey(passphrase, saltBytes) b, _ := aes.NewCipher(key) aesgcm, _ := cipher.NewGCM(b) - plaintext, err = aesgcm.Open(nil, iv, data, nil) + plaintext, err = aesgcm.Open(nil, ivBytes, data, nil) return } diff --git a/crypto_test.go b/crypto_test.go index 5e2dfca..3bafec9 100644 --- a/crypto_test.go +++ b/crypto_test.go @@ -8,19 +8,16 @@ import ( func TestEncrypt(t *testing.T) { key := GetRandomName() fmt.Println(key) - encrypted, err := Encrypt([]byte("hello, world"), key) - if err != nil { - t.Error(err) - } + salt, iv, encrypted := Encrypt([]byte("hello, world"), key) fmt.Println(len(encrypted)) - decrypted, err := Decrypt(encrypted, key) + decrypted, err := Decrypt(salt, iv, encrypted, key) if err != nil { t.Error(err) } if string(decrypted) != "hello, world" { t.Error("problem decrypting") } - _, err = Decrypt(encrypted, "wrong passphrase") + _, err = Decrypt(salt, iv, encrypted, "wrong passphrase") if err == nil { t.Error("should not work!") } diff --git a/main.go b/main.go index e31a604..341960b 100644 --- a/main.go +++ b/main.go @@ -56,17 +56,20 @@ func main() { if connectionTypeFlag == "s" { // encrypt the file + fmt.Println("encrypting...") fdata, err := ioutil.ReadFile(fileName) if err != nil { log.Fatal(err) return } - encrypted, err := Encrypt(fdata, codePhraseFlag) + encrypted, salt, iv := Encrypt(fdata, codePhraseFlag) if err != nil { log.Fatal(err) return } ioutil.WriteFile(fileName+".encrypted", encrypted, 0644) + ioutil.WriteFile(fileName+".salt", []byte(salt), 0644) + ioutil.WriteFile(fileName+".iv", []byte(iv), 0644) } log.SetFormatter(&log.TextFormatter{})