This commit is contained in:
Zack Scholl 2021-04-23 14:10:54 -07:00
parent f68e194a4a
commit 496a6d3a05
2 changed files with 46 additions and 24 deletions

View File

@ -10,6 +10,7 @@ import (
"filippo.io/age"
log "github.com/schollz/logger"
"golang.org/x/crypto/bcrypt"
"github.com/schollz/croc/v9/src/comm"
"github.com/schollz/croc/v9/src/crypt"
@ -160,12 +161,26 @@ func (s *server) clientCommunication(port string, c *comm.Comm) (room string, er
if err != nil {
return
}
if bytes.Equal(retBytesEnc, []byte("ping")) {
room = pingRoom
c.Send([]byte("pong"))
return
}
retBytes, err := crypt.DecryptAge(retBytesEnc, s.keyPrivate)
if err != nil {
return
}
// check whether we have a valid public key from client
keyPublic := string(retBytes)
foo := bytes.Split(retBytes, []byte("--"))
keyPublic := string(foo[0])
hashedPassword := foo[1]
err = bcrypt.CompareHashAndPassword(hashedPassword, []byte(s.password))
if err != nil {
err = fmt.Errorf("bad password")
return
}
_, err = age.ParseX25519Recipient(keyPublic)
if err != nil {
err = fmt.Errorf("bad public key: %s", keyPublic)
@ -370,7 +385,7 @@ func PingServer(address string) (err error) {
// ConnectToTCPServer will initiate a new connection
// to the specified address, room with optional time limit
func ConnectToTCPServer(address, password, room string, timelimit ...time.Duration) (c *comm.Comm, banner string, ipaddr string, err error) {
func ConnectToTCPServer(address, password, keyPublicRelay, room string, timelimit ...time.Duration) (c *comm.Comm, banner string, ipaddr string, err error) {
if len(timelimit) > 0 {
c, err = comm.NewConnection(address, timelimit[0])
} else {
@ -383,12 +398,13 @@ func ConnectToTCPServer(address, password, room string, timelimit ...time.Durati
// generate ephermeral key
keyPublic, keyPrivate, err := crypt.NewAge()
// send epheremal public key, encrypted using the server's public key
foo := strings.Split(password, "--")
keyPublicRelay := foo[1]
password = foo[2]
// send epheremal public key + bcrypted password, encrypted using the server's public key
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), 10)
if err != nil {
return
}
sendBytesEnc, err := crypt.EncryptAge([]byte(keyPublic), keyPublicRelay)
sendBytesEnc, err := crypt.EncryptAge(append([]byte(keyPublic+"--"), hashedPassword...), keyPublicRelay)
if err != nil {
return
}

View File

@ -2,44 +2,50 @@ package tcp
import (
"bytes"
"fmt"
"testing"
"time"
"github.com/schollz/croc/v9/src/crypt"
log "github.com/schollz/logger"
"github.com/stretchr/testify/assert"
)
func BenchmarkConnection(b *testing.B) {
log.SetLevel("trace")
go Run("debug", "8283", "pass123", "8284")
time.Sleep(100 * time.Millisecond)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c, _, _, _ := ConnectToTCPServer("localhost:8283", "pass123", fmt.Sprintf("testroom%d", i), 1*time.Minute)
c.Close()
}
}
// func BenchmarkConnection(b *testing.B) {
// log.SetLevel("trace")
// go Run("debug", "8283", "pass123", "8284")
// time.Sleep(100 * time.Millisecond)
// b.ResetTimer()
// for i := 0; i < b.N; i++ {
// c, _, _, _ := ConnectToTCPServer("localhost:8283", "pass123", fmt.Sprintf("testroom%d", i), 1*time.Minute)
// c.Close()
// }
// }
func TestTCP(t *testing.T) {
log.SetLevel("error")
timeToRoomDeletion = 100 * time.Millisecond
go Run("debug", "8281", "pass123", "8282")
keyPublic, keyPrivate, err := crypt.NewAge()
if err != nil {
panic(err)
}
go Run("debug", "8281", "pass123", keyPublic, keyPrivate, "8282")
time.Sleep(100 * time.Millisecond)
err := PingServer("localhost:8281")
err = PingServer("localhost:8281")
assert.Nil(t, err)
err = PingServer("localhost:8333")
assert.NotNil(t, err)
time.Sleep(100 * time.Millisecond)
c1, banner, _, err := ConnectToTCPServer("localhost:8281", "pass123", "testRoom", 1*time.Minute)
c1, banner, _, err := ConnectToTCPServer("localhost:8281", "pass123", keyPublic, "testRoom", 1*time.Minute)
assert.Equal(t, banner, "8282")
assert.Nil(t, err)
c2, _, _, err := ConnectToTCPServer("localhost:8281", "pass123", "testRoom")
c2, _, _, err := ConnectToTCPServer("localhost:8281", "pass123", keyPublic, "testRoom")
assert.Nil(t, err)
_, _, _, err = ConnectToTCPServer("localhost:8281", "pass123", "testRoom")
_, _, _, err = ConnectToTCPServer("localhost:8281", "pass123", keyPublic, "testRoom")
assert.NotNil(t, err)
_, _, _, err = ConnectToTCPServer("localhost:8281", "pass123", "testRoom", 1*time.Nanosecond)
_, _, _, err = ConnectToTCPServer("localhost:8281", "pass123", keyPublic, "testRoom", 1*time.Nanosecond)
assert.NotNil(t, err)
_, _, _, err = ConnectToTCPServer("localhost:8281", "pass123", keyPublic+"askldjfklsajdf", "testRoom", 1*time.Nanosecond)
assert.NotNil(t, err)
// try sending data