mirror of https://github.com/schollz/croc.git
bug fix: prevent crazy number of bytes getting into comm
This commit is contained in:
parent
44c3d43fa0
commit
b60a841044
16
main.go
16
main.go
|
@ -13,7 +13,21 @@ import (
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// "github.com/pkg/profile"
|
// "github.com/pkg/profile"
|
||||||
// defer profile.Start(profile.CPUProfile).Stop()
|
// go func() {
|
||||||
|
// for {
|
||||||
|
// f, err := os.Create("croc.pprof")
|
||||||
|
// if err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
// runtime.GC() // get up-to-date statistics
|
||||||
|
// if err := pprof.WriteHeapProfile(f); err != nil {
|
||||||
|
// panic(err)
|
||||||
|
// }
|
||||||
|
// f.Close()
|
||||||
|
// time.Sleep(3 * time.Second)
|
||||||
|
// fmt.Println("wrote profile")
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
if err := cli.Run(); err != nil {
|
if err := cli.Run(); err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,8 +8,11 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
"github.com/schollz/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const MAXBYTES = 1000000
|
||||||
|
|
||||||
// Comm is some basic TCP communication
|
// Comm is some basic TCP communication
|
||||||
type Comm struct {
|
type Comm struct {
|
||||||
connection net.Conn
|
connection net.Conn
|
||||||
|
@ -89,12 +92,18 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) {
|
||||||
rbuf := bytes.NewReader(header)
|
rbuf := bytes.NewReader(header)
|
||||||
err = binary.Read(rbuf, binary.LittleEndian, &numBytesUint32)
|
err = binary.Read(rbuf, binary.LittleEndian, &numBytesUint32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println("binary.Read failed:", err)
|
err = fmt.Errorf("binary.Read failed: %s", err.Error())
|
||||||
|
return
|
||||||
}
|
}
|
||||||
numBytes = int(numBytesUint32)
|
numBytes = int(numBytesUint32)
|
||||||
|
if numBytes > MAXBYTES {
|
||||||
|
err = fmt.Errorf("too many bytes: %d", numBytes)
|
||||||
|
logger.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
buf = make([]byte, 0)
|
buf = make([]byte, 0)
|
||||||
for {
|
for {
|
||||||
// log.Debugf("bytes: %d/%d",len(buf),numBytes)
|
// log.Debugf("bytes: %d/%d", len(buf), numBytes)
|
||||||
tmp := make([]byte, numBytes-len(buf))
|
tmp := make([]byte, numBytes-len(buf))
|
||||||
n, errRead := c.connection.Read(tmp)
|
n, errRead := c.connection.Read(tmp)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
|
|
Loading…
Reference in New Issue