diff --git a/main.go b/main.go index f1d9237..64ea4e9 100644 --- a/main.go +++ b/main.go @@ -13,7 +13,21 @@ import ( func main() { // "github.com/pkg/profile" - // defer profile.Start(profile.CPUProfile).Stop() + // go func() { + // for { + // f, err := os.Create("croc.pprof") + // if err != nil { + // panic(err) + // } + // runtime.GC() // get up-to-date statistics + // if err := pprof.WriteHeapProfile(f); err != nil { + // panic(err) + // } + // f.Close() + // time.Sleep(3 * time.Second) + // fmt.Println("wrote profile") + // } + // }() if err := cli.Run(); err != nil { fmt.Println(err) } diff --git a/src/comm/comm.go b/src/comm/comm.go index 4976126..eafd6d7 100644 --- a/src/comm/comm.go +++ b/src/comm/comm.go @@ -8,8 +8,11 @@ import ( "time" "github.com/pkg/errors" + "github.com/schollz/logger" ) +const MAXBYTES = 1000000 + // Comm is some basic TCP communication type Comm struct { connection net.Conn @@ -89,12 +92,18 @@ func (c *Comm) Read() (buf []byte, numBytes int, bs []byte, err error) { rbuf := bytes.NewReader(header) err = binary.Read(rbuf, binary.LittleEndian, &numBytesUint32) if err != nil { - fmt.Println("binary.Read failed:", err) + err = fmt.Errorf("binary.Read failed: %s", err.Error()) + return } numBytes = int(numBytesUint32) + if numBytes > MAXBYTES { + err = fmt.Errorf("too many bytes: %d", numBytes) + logger.Error(err) + return + } buf = make([]byte, 0) for { - // log.Debugf("bytes: %d/%d",len(buf),numBytes) + // log.Debugf("bytes: %d/%d", len(buf), numBytes) tmp := make([]byte, numBytes-len(buf)) n, errRead := c.connection.Read(tmp) if errRead != nil {