From 1f7a72467e5ea285fc89616f53fa45f10f589d1e Mon Sep 17 00:00:00 2001 From: Zack Scholl Date: Tue, 26 Jan 2021 10:47:14 -0800 Subject: [PATCH] use magic bytes instead of maxbytes to check validity --- src/comm/comm.go | 22 +++++++++++++++------- src/comm/comm_test.go | 2 +- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/comm/comm.go b/src/comm/comm.go index d1b77db..4d5fafe 100644 --- a/src/comm/comm.go +++ b/src/comm/comm.go @@ -15,7 +15,7 @@ import ( var Socks5Proxy = "" -const MAXBYTES = 4000000 +var MAGIC_BYTES = []byte("croc") // Comm is some basic TCP communication type Comm struct { @@ -84,6 +84,7 @@ func (c *Comm) Write(b []byte) (n int, err error) { fmt.Println("binary.Write failed:", err) } tmpCopy := append(header.Bytes(), b...) + tmpCopy = append(MAGIC_BYTES, tmpCopy...) n, err = c.connection.Write(tmpCopy) if err != nil { err = fmt.Errorf("connection.Write failed: %w", err) @@ -104,13 +105,25 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { // must clear the timeout setting defer c.connection.SetDeadline(time.Time{}) - // read until we get 4 bytes for the header + // read until we get 4 bytes for the magic header := make([]byte, 4) _, err = io.ReadFull(c.connection, header) if err != nil { log.Debugf("initial read error: %v", err) return } + if !bytes.Equal(header, MAGIC_BYTES) { + err = fmt.Errorf("initial bytes are not magic: %x", header) + return + } + + // read until we get 4 bytes for the header + header = make([]byte, 4) + _, err = io.ReadFull(c.connection, header) + if err != nil { + log.Debugf("initial read error: %v", err) + return + } var numBytesUint32 uint32 rbuf := bytes.NewReader(header) @@ -121,11 +134,6 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { return } numBytes = int(numBytesUint32) - if numBytes > MAXBYTES { - err = fmt.Errorf("too many bytes: %d", numBytes) - log.Debug(err) - return - } // shorten the reading deadline in case getting weird data if err := c.connection.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil { diff --git a/src/comm/comm_test.go b/src/comm/comm_test.go index 9037075..e69ae73 100644 --- a/src/comm/comm_test.go +++ b/src/comm/comm_test.go @@ -11,7 +11,7 @@ import ( ) func TestComm(t *testing.T) { - token := make([]byte, MAXBYTES) + token := make([]byte, 3000) if _, err := rand.Read(token); err != nil { t.Error(err) }