Encryption works, cleanup is good

This commit is contained in:
Zack Scholl 2017-10-17 21:15:48 -06:00
parent e59df2e617
commit 0cf680fd66
4 changed files with 73 additions and 27 deletions

View File

@ -3,6 +3,7 @@ package main
import (
"fmt"
"io"
"io/ioutil"
"math"
"net"
"os"
@ -29,7 +30,7 @@ func runClient(connectionType string, codePhrase string) {
uiprogress.Start()
bars = make([]*uiprogress.Bar, numberConnections)
fileNameToReceive := ""
var iv, salt, fileNameToReceive string
for id := 0; id < numberConnections; id++ {
go func(id int) {
defer wg.Done()
@ -60,7 +61,7 @@ func runClient(connectionType string, codePhrase string) {
} else { // this is a receiver
// receive file
logger.Debug("receive file")
fileNameToReceive = receiveFile(id, connection, codePhrase)
fileNameToReceive, iv, salt = receiveFile(id, connection, codePhrase)
}
}(id)
@ -69,13 +70,32 @@ func runClient(connectionType string, codePhrase string) {
if connectionType == "r" {
catFile(fileNameToReceive)
encrypted, err := ioutil.ReadFile(fileNameToReceive + ".encrypted")
if err != nil {
log.Error(err)
return
}
fmt.Println("\n\ndecrypting...")
decrypted, err := Decrypt(encrypted, codePhrase, salt, iv)
if err != nil {
log.Error(err)
return
}
ioutil.WriteFile(fileNameToReceive, decrypted, 0644)
os.Remove(fileNameToReceive + ".encrypted")
fmt.Println("\nDownloaded " + fileNameToReceive + "!")
} else {
log.Info("cleaning up")
os.Remove(fileName + ".encrypted")
os.Remove(fileName + ".iv")
os.Remove(fileName + ".salt")
}
}
func catFile(fileNameToReceive string) {
// cat the file
os.Remove(fileNameToReceive)
finished, err := os.Create(fileNameToReceive)
finished, err := os.Create(fileNameToReceive + ".encrypted")
defer finished.Close()
if err != nil {
log.Fatal(err)
@ -94,24 +114,35 @@ func catFile(fileNameToReceive string) {
os.Remove(fileNameToReceive + "." + strconv.Itoa(id))
}
fmt.Println("\n\n\nDownloaded " + fileNameToReceive + "!")
}
func receiveFile(id int, connection net.Conn, codePhrase string) string {
func receiveFile(id int, connection net.Conn, codePhrase string) (fileNameToReceive string, iv string, salt string) {
logger := log.WithFields(log.Fields{
"function": "receiveFile #" + strconv.Itoa(id),
})
bufferFileName := make([]byte, 64)
bufferFileSize := make([]byte, 10)
logger.Debug("waiting for file size")
bufferFileSize := make([]byte, 10)
connection.Read(bufferFileSize)
fileSize, _ := strconv.ParseInt(strings.Trim(string(bufferFileSize), ":"), 10, 64)
logger.Debugf("filesize: %d", fileSize)
bufferFileName := make([]byte, 64)
connection.Read(bufferFileName)
fileNameToReceive := strings.Trim(string(bufferFileName), ":")
fileNameToReceive = strings.Trim(string(bufferFileName), ":")
logger.Debugf("fileName: %v", fileNameToReceive)
ivHex := make([]byte, BUFFERSIZE)
connection.Read(ivHex)
iv = strings.Trim(string(ivHex), ":")
logger.Debugf("iv: %v", iv)
saltHex := make([]byte, BUFFERSIZE)
connection.Read(saltHex)
salt = strings.Trim(string(saltHex), ":")
logger.Debugf("salt: %v", salt)
os.Remove(fileNameToReceive + "." + strconv.Itoa(id))
newFile, err := os.Create(fileNameToReceive + "." + strconv.Itoa(id))
if err != nil {
@ -140,7 +171,7 @@ func receiveFile(id int, connection net.Conn, codePhrase string) string {
receivedBytes += BUFFERSIZE
}
logger.Debug("received file")
return fileNameToReceive
return
}
func sendFile(id int, connection net.Conn, codePhrase string) {
@ -185,6 +216,25 @@ func sendFile(id int, connection net.Conn, codePhrase string) {
connection.Write([]byte(fileSize))
logger.Debugf("sending fileNameToSend: %s", fileNameToSend)
connection.Write([]byte(fileNameToSend))
// send iv
iv, err := ioutil.ReadFile(fileName + ".iv")
if err != nil {
log.Error(err)
return
}
logger.Debugf("sending iv: %s", iv)
connection.Write([]byte(fillString(string(iv), BUFFERSIZE)))
// send salt
salt, err := ioutil.ReadFile(fileName + ".salt")
if err != nil {
log.Error(err)
return
}
logger.Debugf("sending salt: %s", salt)
connection.Write([]byte(fillString(string(salt), BUFFERSIZE)))
sendBuffer := make([]byte, BUFFERSIZE)
chunkI := 0

View File

@ -1,7 +1,6 @@
package main
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
@ -29,7 +28,7 @@ func GetRandomName() string {
return strings.Join(result, "-")
}
func Encrypt(plaintext []byte, passphrase string) (ciphertext []byte, err error) {
func Encrypt(plaintext []byte, passphrase string) ([]byte, string, string) {
key, salt := deriveKey(passphrase, nil)
iv := make([]byte, 12)
// http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf
@ -38,19 +37,16 @@ func Encrypt(plaintext []byte, passphrase string) (ciphertext []byte, err error)
b, _ := aes.NewCipher(key)
aesgcm, _ := cipher.NewGCM(b)
data := aesgcm.Seal(nil, iv, plaintext, nil)
ciphertext = []byte(hex.EncodeToString(salt) + "-" + hex.EncodeToString(iv) + "-" + hex.EncodeToString(data))
return
return data, hex.EncodeToString(salt), hex.EncodeToString(iv)
}
func Decrypt(ciphertext []byte, passphrase string) (plaintext []byte, err error) {
arr := bytes.Split(ciphertext, []byte("-"))
salt, _ := hex.DecodeString(string(arr[0]))
iv, _ := hex.DecodeString(string(arr[1]))
data, _ := hex.DecodeString(string(arr[2]))
key, _ := deriveKey(passphrase, salt)
func Decrypt(data []byte, passphrase string, salt string, iv string) (plaintext []byte, err error) {
saltBytes, _ := hex.DecodeString(salt)
ivBytes, _ := hex.DecodeString(iv)
key, _ := deriveKey(passphrase, saltBytes)
b, _ := aes.NewCipher(key)
aesgcm, _ := cipher.NewGCM(b)
plaintext, err = aesgcm.Open(nil, iv, data, nil)
plaintext, err = aesgcm.Open(nil, ivBytes, data, nil)
return
}

View File

@ -8,19 +8,16 @@ import (
func TestEncrypt(t *testing.T) {
key := GetRandomName()
fmt.Println(key)
encrypted, err := Encrypt([]byte("hello, world"), key)
if err != nil {
t.Error(err)
}
salt, iv, encrypted := Encrypt([]byte("hello, world"), key)
fmt.Println(len(encrypted))
decrypted, err := Decrypt(encrypted, key)
decrypted, err := Decrypt(salt, iv, encrypted, key)
if err != nil {
t.Error(err)
}
if string(decrypted) != "hello, world" {
t.Error("problem decrypting")
}
_, err = Decrypt(encrypted, "wrong passphrase")
_, err = Decrypt(salt, iv, encrypted, "wrong passphrase")
if err == nil {
t.Error("should not work!")
}

View File

@ -56,17 +56,20 @@ func main() {
if connectionTypeFlag == "s" {
// encrypt the file
fmt.Println("encrypting...")
fdata, err := ioutil.ReadFile(fileName)
if err != nil {
log.Fatal(err)
return
}
encrypted, err := Encrypt(fdata, codePhraseFlag)
encrypted, salt, iv := Encrypt(fdata, codePhraseFlag)
if err != nil {
log.Fatal(err)
return
}
ioutil.WriteFile(fileName+".encrypted", encrypted, 0644)
ioutil.WriteFile(fileName+".salt", []byte(salt), 0644)
ioutil.WriteFile(fileName+".iv", []byte(iv), 0644)
}
log.SetFormatter(&log.TextFormatter{})