mirror of https://github.com/schollz/croc.git
use magic bytes instead of maxbytes to check validity
This commit is contained in:
parent
e829ca0ff4
commit
1f7a72467e
|
@ -15,7 +15,7 @@ import (
|
||||||
|
|
||||||
var Socks5Proxy = ""
|
var Socks5Proxy = ""
|
||||||
|
|
||||||
const MAXBYTES = 4000000
|
var MAGIC_BYTES = []byte("croc")
|
||||||
|
|
||||||
// Comm is some basic TCP communication
|
// Comm is some basic TCP communication
|
||||||
type Comm struct {
|
type Comm struct {
|
||||||
|
@ -84,6 +84,7 @@ func (c *Comm) Write(b []byte) (n int, err error) {
|
||||||
fmt.Println("binary.Write failed:", err)
|
fmt.Println("binary.Write failed:", err)
|
||||||
}
|
}
|
||||||
tmpCopy := append(header.Bytes(), b...)
|
tmpCopy := append(header.Bytes(), b...)
|
||||||
|
tmpCopy = append(MAGIC_BYTES, tmpCopy...)
|
||||||
n, err = c.connection.Write(tmpCopy)
|
n, err = c.connection.Write(tmpCopy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("connection.Write failed: %w", err)
|
err = fmt.Errorf("connection.Write failed: %w", err)
|
||||||
|
@ -104,13 +105,25 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
|
||||||
// must clear the timeout setting
|
// must clear the timeout setting
|
||||||
defer c.connection.SetDeadline(time.Time{})
|
defer c.connection.SetDeadline(time.Time{})
|
||||||
|
|
||||||
// read until we get 4 bytes for the header
|
// read until we get 4 bytes for the magic
|
||||||
header := make([]byte, 4)
|
header := make([]byte, 4)
|
||||||
_, err = io.ReadFull(c.connection, header)
|
_, err = io.ReadFull(c.connection, header)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("initial read error: %v", err)
|
log.Debugf("initial read error: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if !bytes.Equal(header, MAGIC_BYTES) {
|
||||||
|
err = fmt.Errorf("initial bytes are not magic: %x", header)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// read until we get 4 bytes for the header
|
||||||
|
header = make([]byte, 4)
|
||||||
|
_, err = io.ReadFull(c.connection, header)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("initial read error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var numBytesUint32 uint32
|
var numBytesUint32 uint32
|
||||||
rbuf := bytes.NewReader(header)
|
rbuf := bytes.NewReader(header)
|
||||||
|
@ -121,11 +134,6 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
numBytes = int(numBytesUint32)
|
numBytes = int(numBytesUint32)
|
||||||
if numBytes > MAXBYTES {
|
|
||||||
err = fmt.Errorf("too many bytes: %d", numBytes)
|
|
||||||
log.Debug(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// shorten the reading deadline in case getting weird data
|
// shorten the reading deadline in case getting weird data
|
||||||
if err := c.connection.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
if err := c.connection.SetReadDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||||
|
|
|
@ -11,7 +11,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestComm(t *testing.T) {
|
func TestComm(t *testing.T) {
|
||||||
token := make([]byte, MAXBYTES)
|
token := make([]byte, 3000)
|
||||||
if _, err := rand.Read(token); err != nil {
|
if _, err := rand.Read(token); err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue