use magic bytes instead of maxbytes to check validity

This commit is contained in:
Zack Scholl 2021-01-26 10:47:14 -08:00
parent e829ca0ff4
commit 1f7a72467e
2 changed files with 16 additions and 8 deletions

View File

@ -15,7 +15,7 @@ import (
var Socks5Proxy = "" var Socks5Proxy = ""
const MAXBYTES = 4000000 var MAGIC_BYTES = []byte("croc")
// Comm is some basic TCP communication // Comm is some basic TCP communication
type Comm struct { type Comm struct {
@ -84,6 +84,7 @@ func (c *Comm) Write(b []byte) (n int, err error) {
fmt.Println("binary.Write failed:", err) fmt.Println("binary.Write failed:", err)
} }
tmpCopy := append(header.Bytes(), b...) tmpCopy := append(header.Bytes(), b...)
tmpCopy = append(MAGIC_BYTES, tmpCopy...)
n, err = c.connection.Write(tmpCopy) n, err = c.connection.Write(tmpCopy)
if err != nil { if err != nil {
err = fmt.Errorf("connection.Write failed: %w", err) err = fmt.Errorf("connection.Write failed: %w", err)
@ -104,13 +105,25 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
// must clear the timeout setting // must clear the timeout setting
defer c.connection.SetDeadline(time.Time{}) defer c.connection.SetDeadline(time.Time{})
// read until we get 4 bytes for the header // read until we get 4 bytes for the magic
header := make([]byte, 4) header := make([]byte, 4)
_, err = io.ReadFull(c.connection, header) _, err = io.ReadFull(c.connection, header)
if err != nil { if err != nil {
log.Debugf("initial read error: %v", err) log.Debugf("initial read error: %v", err)
return return
} }
if !bytes.Equal(header, MAGIC_BYTES) {
err = fmt.Errorf("initial bytes are not magic: %x", header)
return
}
// read until we get 4 bytes for the header
header = make([]byte, 4)
_, err = io.ReadFull(c.connection, header)
if err != nil {
log.Debugf("initial read error: %v", err)
return
}
var numBytesUint32 uint32 var numBytesUint32 uint32
rbuf := bytes.NewReader(header) rbuf := bytes.NewReader(header)
@ -121,11 +134,6 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
return return
} }
numBytes = int(numBytesUint32) numBytes = int(numBytesUint32)
if numBytes > MAXBYTES {
err = fmt.Errorf("too many bytes: %d", numBytes)
log.Debug(err)
return
}
// shorten the reading deadline in case getting weird data // shorten the reading deadline in case getting weird data
if err := c.connection.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil { if err := c.connection.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {

View File

@ -11,7 +11,7 @@ import (
) )
func TestComm(t *testing.T) { func TestComm(t *testing.T) {
token := make([]byte, MAXBYTES) token := make([]byte, 3000)
if _, err := rand.Read(token); err != nil { if _, err := rand.Read(token); err != nil {
t.Error(err) t.Error(err)
} }