diff --git a/src/crypt/crypt.go b/src/crypt/crypt.go index 65ba77c..b864d46 100644 --- a/src/crypt/crypt.go +++ b/src/crypt/crypt.go @@ -5,56 +5,37 @@ import ( "crypto/cipher" "crypto/rand" "crypto/sha256" + "fmt" "golang.org/x/crypto/pbkdf2" ) -// Encryption is the basic type for storing -// the key, passphrase and salt -type Encryption struct { - key []byte - passphrase []byte - salt []byte -} - -// New generates a new Encryption, using the supplied passphrase and -// an optional supplied salt. -// Passing nil passphrase will not use decryption. -func New(passphrase []byte, salt []byte) (e Encryption, err error) { - if passphrase == nil { - e = Encryption{nil, nil, nil} +// New generates a new key based on a passphrase and salt +func New(passphrase []byte, usersalt []byte) (key []byte, salt []byte, err error) { + if len(passphrase) < 1 { + err = fmt.Errorf("need more than that for passphrase") return } - e.passphrase = passphrase - if salt == nil { - e.salt = make([]byte, 8) + if usersalt == nil { + salt = make([]byte, 8) // http://www.ietf.org/rfc/rfc2898.txt // Salt. - rand.Read(e.salt) + rand.Read(salt) } else { - e.salt = salt + salt = usersalt } - e.key = pbkdf2.Key([]byte(passphrase), e.salt, 100, 32, sha256.New) + key = pbkdf2.Key([]byte(passphrase), salt, 100, 32, sha256.New) return } -// Salt returns the salt bytes -func (e Encryption) Salt() []byte { - return e.salt -} - -// Encrypt will generate an Encryption, prefixed with the IV -func (e Encryption) Encrypt(plaintext []byte) (encrypted []byte, err error) { - if e.passphrase == nil { - encrypted = plaintext - return - } +// Encrypt will encrypt using the pre-generated key +func Encrypt(plaintext []byte, key []byte) (encrypted []byte, err error) { // generate a random iv each time // http://nvlpubs.nist.gov/nistpubs/Legacy/SP/nistspecialpublication800-38d.pdf // Section 8.2 ivBytes := make([]byte, 12) rand.Read(ivBytes) - b, err := aes.NewCipher(e.key) + b, err := aes.NewCipher(key) if err != nil { return } @@ -67,13 +48,9 @@ func (e Encryption) Encrypt(plaintext []byte) (encrypted []byte, err error) { return } -// Decrypt an Encryption -func (e Encryption) Decrypt(encrypted []byte) (plaintext []byte, err error) { - if e.passphrase == nil { - plaintext = encrypted - return - } - b, err := aes.NewCipher(e.key) +// Decrypt using the pre-generated key +func Decrypt(encrypted []byte, key []byte) (plaintext []byte, err error) { + b, err := aes.NewCipher(key) if err != nil { return } diff --git a/src/crypt/crypt_test.go b/src/crypt/crypt_test.go index 14ac639..d03c4bd 100644 --- a/src/crypt/crypt_test.go +++ b/src/crypt/crypt_test.go @@ -6,55 +6,42 @@ import ( "github.com/stretchr/testify/assert" ) -func BenchmarkEncryptionNew(b *testing.B) { +func BenchmarkEncrypt(b *testing.B) { + bob, _, _ := New([]byte("password"), nil) for i := 0; i < b.N; i++ { - bob, _ := New([]byte("password"), nil) - bob.Encrypt([]byte("hello, world")) + Encrypt([]byte("hello, world"), bob) } } -func BenchmarkEncryption(b *testing.B) { - bob, _ := New([]byte("password"), nil) +func BenchmarkDecrypt(b *testing.B) { + key, _, _ := New([]byte("password"), nil) + msg := []byte("hello, world") + enc, _ := Encrypt(msg, key) + b.ResetTimer() for i := 0; i < b.N; i++ { - bob.Encrypt([]byte("hello, world")) + Decrypt(enc, key) } } func TestEncryption(t *testing.T) { - bob, err := New([]byte("password"), nil) + key, salt, err := New([]byte("password"), nil) assert.Nil(t, err) - jane, err := New([]byte("password"), bob.Salt()) + msg := []byte("hello, world") + enc, err := Encrypt(msg, key) assert.Nil(t, err) - enc, err := bob.Encrypt([]byte("hello, world")) + dec, err := Decrypt(enc, key) assert.Nil(t, err) - dec, err := jane.Decrypt(enc) - assert.Nil(t, err) - assert.Equal(t, dec, []byte("hello, world")) + assert.Equal(t, msg, dec) - jane2, err := New([]byte("password"), nil) + // check reusing the salt + key2, _, err := New([]byte("password"), salt) + dec, err = Decrypt(enc, key2) assert.Nil(t, err) - dec, err = jane2.Decrypt(enc) + assert.Equal(t, msg, dec) + + // check reusing the salt + key2, _, err = New([]byte("wrong password"), salt) + dec, err = Decrypt(enc, key2) assert.NotNil(t, err) - assert.NotEqual(t, dec, []byte("hello, world")) - - jane3, err := New([]byte("passwordwrong"), bob.Salt()) - assert.Nil(t, err) - dec, err = jane3.Decrypt(enc) - assert.NotNil(t, err) - assert.NotEqual(t, dec, []byte("hello, world")) - -} - -func TestNoEncryption(t *testing.T) { - bob, err := New(nil, nil) - assert.Nil(t, err) - jane, err := New(nil, nil) - assert.Nil(t, err) - enc, err := bob.Encrypt([]byte("hello, world")) - assert.Nil(t, err) - dec, err := jane.Decrypt(enc) - assert.Nil(t, err) - assert.Equal(t, dec, []byte("hello, world")) - assert.Equal(t, enc, []byte("hello, world")) - + assert.NotEqual(t, msg, dec) }