mirror of https://github.com/schollz/croc.git
use magic bytes instead of maxbytes to check validity
This commit is contained in:
parent
e829ca0ff4
commit
1f7a72467e
|
@ -15,7 +15,7 @@ import (
|
|||
|
||||
var Socks5Proxy = ""
|
||||
|
||||
const MAXBYTES = 4000000
|
||||
var MAGIC_BYTES = []byte("croc")
|
||||
|
||||
// Comm is some basic TCP communication
|
||||
type Comm struct {
|
||||
|
@ -84,6 +84,7 @@ func (c *Comm) Write(b []byte) (n int, err error) {
|
|||
fmt.Println("binary.Write failed:", err)
|
||||
}
|
||||
tmpCopy := append(header.Bytes(), b...)
|
||||
tmpCopy = append(MAGIC_BYTES, tmpCopy...)
|
||||
n, err = c.connection.Write(tmpCopy)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("connection.Write failed: %w", err)
|
||||
|
@ -104,13 +105,25 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
|
|||
// must clear the timeout setting
|
||||
defer c.connection.SetDeadline(time.Time{})
|
||||
|
||||
// read until we get 4 bytes for the header
|
||||
// read until we get 4 bytes for the magic
|
||||
header := make([]byte, 4)
|
||||
_, err = io.ReadFull(c.connection, header)
|
||||
if err != nil {
|
||||
log.Debugf("initial read error: %v", err)
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(header, MAGIC_BYTES) {
|
||||
err = fmt.Errorf("initial bytes are not magic: %x", header)
|
||||
return
|
||||
}
|
||||
|
||||
// read until we get 4 bytes for the header
|
||||
header = make([]byte, 4)
|
||||
_, err = io.ReadFull(c.connection, header)
|
||||
if err != nil {
|
||||
log.Debugf("initial read error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var numBytesUint32 uint32
|
||||
rbuf := bytes.NewReader(header)
|
||||
|
@ -121,11 +134,6 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
|
|||
return
|
||||
}
|
||||
numBytes = int(numBytesUint32)
|
||||
if numBytes > MAXBYTES {
|
||||
err = fmt.Errorf("too many bytes: %d", numBytes)
|
||||
log.Debug(err)
|
||||
return
|
||||
}
|
||||
|
||||
// shorten the reading deadline in case getting weird data
|
||||
if err := c.connection.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func TestComm(t *testing.T) {
|
||||
token := make([]byte, MAXBYTES)
|
||||
token := make([]byte, 3000)
|
||||
if _, err := rand.Read(token); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue