diff --git a/go.mod b/go.mod index 2dbdf61..3df0dbb 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/schollz/croc/v9 go 1.13 require ( - filippo.io/age v1.0.0-rc.1 // indirect + filippo.io/age v1.0.0-rc.1 github.com/OneOfOne/xxhash v1.2.5 // indirect github.com/cespare/xxhash v1.1.0 github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect diff --git a/src/crypt/crypt.go b/src/crypt/crypt.go index b1b431c..712eba1 100644 --- a/src/crypt/crypt.go +++ b/src/crypt/crypt.go @@ -86,7 +86,7 @@ func NewAge() (pubkey string, privkey string, err error) { return } -func EncryptAge(pubkey string, data []byte) (encrypted []byte, err error) { +func EncryptAge(data []byte, pubkey string) (encrypted []byte, err error) { recipient, err := age.ParseX25519Recipient(pubkey) if err != nil { return @@ -109,7 +109,7 @@ func EncryptAge(pubkey string, data []byte) (encrypted []byte, err error) { return } -func DecryptAge(privkey string, encrypted []byte) (data []byte, err error) { +func DecryptAge(encrypted []byte, privkey string) (data []byte, err error) { identity, err := age.ParseX25519Identity(privkey) if err != nil { return diff --git a/src/crypt/crypt_test.go b/src/crypt/crypt_test.go index d47dd4c..eed6fc7 100644 --- a/src/crypt/crypt_test.go +++ b/src/crypt/crypt_test.go @@ -103,3 +103,15 @@ func TestEncryptionChaCha(t *testing.T) { _, _, err = NewArgon2([]byte(""), nil) assert.NotNil(t, err) } + +func TestEncryptionAge(t *testing.T) { + pub, priv, err := NewAge() + fmt.Printf("key: %s\n", pub) + assert.Nil(t, err) + msg := []byte("hello, world") + enc, err := EncryptAge(msg, pub) + assert.Nil(t, err) + dec, err := DecryptAge(enc, priv) + assert.Nil(t, err) + assert.Equal(t, msg, dec) +} diff --git a/src/tcp/tcp.go b/src/tcp/tcp.go index a578b14..5613fcd 100644 --- a/src/tcp/tcp.go +++ b/src/tcp/tcp.go @@ -8,10 +8,9 @@ import ( "sync" "time" - log "github.com/schollz/logger" - "github.com/schollz/pake/v3" - "filippo.io/age" + log "github.com/schollz/logger" + "github.com/schollz/croc/v9/src/comm" "github.com/schollz/croc/v9/src/crypt" "github.com/schollz/croc/v9/src/models" @@ -19,8 +18,8 @@ import ( type server struct { port string - privateKey string - publicKey string + keyPrivate string + keyPublic string debugLevel string banner string password string @@ -43,9 +42,11 @@ var timeToRoomDeletion = 10 * time.Minute var pingRoom = "pinglkasjdlfjsaldjf" // Run starts a tcp listener, run async -func Run(debugLevel, port, password string, banner ...string) (err error) { +func Run(debugLevel, port, password string, keyPublic string, keyPrivate string, banner ...string) (err error) { s := new(server) s.port = port + s.keyPrivate = keyPrivate + s.keyPublic = keyPublic s.password = password s.debugLevel = debugLevel if len(banner) > 0 { @@ -154,68 +155,21 @@ func (s *server) run() (err error) { var weakKey = []byte{1, 2, 3} func (s *server) clientCommunication(port string, c *comm.Comm) (room string, err error) { - identity, err := age.GenerateX25519Identity() + // get public key of the connecting client + retBytesEnc, err := c.Receive() if err != nil { return } - // send public key for encryption - c.Send([]byte(identity.Recipient().String())) - - // establish secure password with PAKE for communication with relay - B, err := pake.InitCurve(weakKey, 1, "siec") + retBytes, err := crypt.DecryptAge(retBytesEnc, s.keyPrivate) if err != nil { return } - Abytes, err := c.Receive() + // check whether we have a valid public key from client + keyPublic := string(retBytes) + _, err := age.ParseX25519Recipient(keyPublic) if err != nil { return } - if bytes.Equal(Abytes, []byte("ping")) { - room = pingRoom - c.Send([]byte("pong")) - return - } - err = B.Update(Abytes) - if err != nil { - return - } - err = c.Send(B.Bytes()) - if err != nil { - return - } - strongKey, err := B.SessionKey() - if err != nil { - return - } - log.Debugf("strongkey: %x", strongKey) - - // receive salt - salt, err := c.Receive() - if err != nil { - return - } - strongKeyForEncryption, _, err := crypt.New(strongKey, salt) - if err != nil { - return - } - - log.Debugf("waiting for password") - passwordBytesEnc, err := c.Receive() - if err != nil { - return - } - passwordBytes, err := crypt.Decrypt(passwordBytesEnc, strongKeyForEncryption) - if err != nil { - return - } - if strings.TrimSpace(string(passwordBytes)) != s.password { - err = fmt.Errorf("bad password") - enc, _ := crypt.Decrypt([]byte(err.Error()), strongKeyForEncryption) - if err := c.Send(enc); err != nil { - return "", fmt.Errorf("send error: %w", err) - } - return - } // send ok to tell client they are connected banner := s.banner @@ -223,7 +177,7 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er banner = "ok" } log.Debugf("sending '%s'", banner) - bSend, err := crypt.Encrypt([]byte(banner+"|||"+c.Connection().RemoteAddr().String()), strongKeyForEncryption) + bSend, err := crypt.EncryptAge([]byte(banner+"|||"+c.Connection().RemoteAddr().String()), keyPublic) if err != nil { return } @@ -238,7 +192,7 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er if err != nil { return } - roomBytes, err := crypt.Decrypt(enc, strongKeyForEncryption) + roomBytes, err := crypt.DecryptAge(enc, s.keyPrivate) if err != nil { return } @@ -254,7 +208,7 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er s.rooms.Unlock() // tell the client that they got the room - bSend, err = crypt.Encrypt([]byte("ok"), strongKeyForEncryption) + bSend, err = crypt.EncryptAge([]byte("ok"), keyPublic) if err != nil { return } @@ -269,7 +223,7 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er } if s.rooms.rooms[room].full { s.rooms.Unlock() - bSend, err = crypt.Encrypt([]byte("room full"), strongKeyForEncryption) + bSend, err = crypt.EncryptAge([]byte("room full"), keyPublic) if err != nil { return } @@ -425,65 +379,40 @@ func ConnectToTCPServer(address, password, room string, timelimit ...time.Durati return } - // get PAKE connection with server to establish strong key to transfer info - A, err := pake.InitCurve(weakKey, 0, "siec") - if err != nil { - return - } - err = c.Send(A.Bytes()) - if err != nil { - return - } - Bbytes, err := c.Receive() - if err != nil { - return - } - err = A.Update(Bbytes) - if err != nil { - return - } - strongKey, err := A.SessionKey() - if err != nil { - return - } - log.Debugf("strong key: %x", strongKey) + // generate ephermeral key + keyPublic, keyPrivate, err := crypt.NewAge() - strongKeyForEncryption, salt, err := crypt.New(strongKey, nil) - if err != nil { - return - } - // send salt - err = c.Send(salt) + // send epheremal public key, encrypted using the server's public key + foo := strings.Split(password, "--") + keyPublicRelay = foo[1] + password = foo[2] + + sendBytesEnc, err := crypt.EncryptAge([]byte(keyPublic), keyPublicRelay) if err != nil { return } - log.Debug("sending password") - bSend, err := crypt.Encrypt([]byte(password), strongKeyForEncryption) + err = c.Send(sendBytesEnc) if err != nil { return } - err = c.Send(bSend) + retBytesEnc, err := c.Receive() if err != nil { return } - log.Debug("waiting for first ok") - enc, err := c.Receive() + retBytes, err := crypt.DecryptAge(retBytesEnc, keyPrivate) if err != nil { return } - data, err := crypt.Decrypt(enc, strongKeyForEncryption) - if err != nil { + if !strings.Contains(string(retBytes), "|||") { + err = fmt.Errorf("bad response: %s", string(retBytes)) return } - if !strings.Contains(string(data), "|||") { - err = fmt.Errorf("bad response: %s", string(data)) - return - } - banner = strings.Split(string(data), "|||")[0] - ipaddr = strings.Split(string(data), "|||")[1] + banner = strings.Split(string(retBytes), "|||")[0] + ipaddr = strings.Split(string(retBytes), "|||")[1] + log.Debug("sending room") - bSend, err = crypt.Encrypt([]byte(room), strongKeyForEncryption) + bSend, err = crypt.EncryptAge([]byte(room), keyPublicRelay) if err != nil { return } @@ -496,7 +425,7 @@ func ConnectToTCPServer(address, password, room string, timelimit ...time.Durati if err != nil { return } - data, err = crypt.Decrypt(enc, strongKeyForEncryption) + data, err = crypt.DecryptAge(enc, keyPrivate) if err != nil { return }