package utils

import (
	"crypto/sha256"
	"io"
	"mime/multipart"
	"net/http"
	"strconv"

	"github.com/pkg/errors"
)

// chunkSize is the read buffer size used by the part helpers in this file.
// ChunkFilePart emits chunks of exactly this many bytes, except possibly the
// last one.
const chunkSize int = 4096

// FileInfo describes the file part located by FindFilePart and, once
// ChunkFilePart has streamed it, how many bytes were read and written.
type FileInfo struct {
	// FileID is not set by FindFilePart or ChunkFilePart.
	FileID string

	// Filename is the filename reported by the part (part.FileName()).
	Filename string

	// ContentType is the part's Content-Type header (default "").
	ContentType string

	// Part is the unread file part; nil when FindFilePart found no file part.
	Part *multipart.Part

	// FileSize is the number of bytes ChunkFilePart wrote to its pipe, i.e.
	// the size after encryption when an encryptor is used.
	FileSize uint64

	// OrigFileSize starts as the part's Content-Length header (0 when missing
	// or unparsable) and is overwritten by ChunkFilePart, on successful
	// completion, with the number of bytes actually read from Part.
	OrigFileSize uint64
}

// FormParser is called by FindFilePart for each part that is not the file
// being searched for. If it returns nil the part is then closed; a non-nil
// error aborts FindFilePart and is returned to its caller unchanged.
type FormParser func(part *multipart.Part) error

// FileEncryptor transforms one chunk of file data (at most chunkSize bytes)
// and returns the bytes ChunkFilePart should write in its place. The input
// slice is reused for the next chunk, so implementations must not retain it.
type FileEncryptor func(input []byte) []byte

// FindFilePart scans the multipart body of r, delimited by boundary, for the
// first part whose form name is fileFieldName and which carries a filename.
//
// Every part seen before it is passed to formParser (when non-nil) and then
// closed. The matching part is returned unread in FileInfo.Part so the caller
// can stream it (e.g. with ChunkFilePart).
//
// If the body ends without such a part, the returned FileInfo has a nil Part
// and the error is nil, so callers must check Part themselves. Errors from
// the multipart reader are returned with a stack trace; errors from
// formParser are returned as-is.
func FindFilePart(r *http.Request, boundary string, fileFieldName string, formParser FormParser) (*FileInfo, error) {
	result := &FileInfo{}
	mr := multipart.NewReader(r.Body, boundary)
	for {
		part, err := mr.NextPart()
		switch {
		case errors.Is(err, io.EOF):
			return result, nil
		case err != nil:
			return result, errors.WithStack(err)
		}

		if part.FormName() != fileFieldName || part.FileName() == "" {
			if formParser != nil {
				if err = formParser(part); err != nil {
					return result, err
				}
			}
			part.Close()
			continue
		}

		result.Filename = part.FileName()
		result.ContentType = GetHTTPHeader(part.Header, "Content-Type", "")

		// Parse as unsigned 64-bit so files over 2 GiB are not clamped and a
		// negative header cannot wrap into a huge size; anything unparsable
		// (including out-of-range values) is recorded as 0.
		size, parseErr := strconv.ParseUint(GetHTTPHeader(part.Header, "Content-Length", "0"), 10, 64)
		if parseErr != nil {
			size = 0
		}
		result.OrigFileSize = size
		result.Part = part
		return result, nil
	}
}

// ChunkFilePart copies info.Part into writer in chunkSize-byte chunks,
// passing each chunk through encryptor when it is non-nil. io.ReadFull is
// used so every chunk is exactly chunkSize bytes except possibly the last,
// which keeps chunk boundaries stable for per-chunk encryption.
//
// writer is always closed when ChunkFilePart returns:
//   - on success it is closed normally, and info.FileSize / info.OrigFileSize
//     are set to the bytes written and read respectively;
//   - if info or info.Part is nil, or reading the part fails, it is closed
//     with an error so the reading side of the pipe observes the failure;
//   - if a write fails (the reading side has gone away) copying stops.
//
// In the failure cases info is left unchanged.
func ChunkFilePart(writer *io.PipeWriter, info *FileInfo, encryptor FileEncryptor) {
	// Close is a no-op for the error: CloseWithError keeps the first error set.
	defer writer.Close()

	if info == nil || info.Part == nil {
		writer.CloseWithError(errors.New("utils: no file part to chunk"))
		return
	}

	var origFileSize, fileSize uint64
	buf := make([]byte, chunkSize)
	for {
		n, readErr := io.ReadFull(info.Part, buf)
		if n > 0 {
			chunk := buf[:n]
			if encryptor != nil {
				chunk = encryptor(chunk)
			}
			if _, err := writer.Write(chunk); err != nil {
				// Nobody is reading any more; draining the rest of the
				// upload would be wasted work.
				return
			}
			fileSize += uint64(len(chunk))
			origFileSize += uint64(n)
		}

		if readErr == nil {
			continue
		}
		// ErrUnexpectedEOF means ReadFull hit EOF after a short final chunk;
		// either way the part is exhausted.
		if errors.Is(readErr, io.EOF) || errors.Is(readErr, io.ErrUnexpectedEOF) {
			info.FileSize = fileSize
			info.OrigFileSize = origFileSize
			return
		}
		// Any other error is persistent; retrying would spin forever.
		writer.CloseWithError(errors.WithStack(readErr))
		return
	}
}

// PartReadAll returns entire part value.
Maxchars checks that max characters is not exceeded func PartReadAll(part *multipart.Part, maxchars int) (string, error) { var err error n := 0 buf := make([]byte, chunkSize) result := []byte{} for { n, err = part.Read(buf) result = append(result, buf[:n]...) if len(result) > maxchars { return "", errors.Errorf("%s exceeded %d characters", part.FormName(), maxchars) } if errors.Is(err, io.EOF) { return string(result), nil } } } func HashMultipartFile(file multipart.File) ([]byte, error) { var err error n := 0 buf := make([]byte, chunkSize) digest := sha256.New() for { n, err = file.Read(buf) if n > 0 { _, err := digest.Write(buf[:n]) if err != nil { return nil, errors.WithStack(err) } } if errors.Is(err, io.EOF) { break } } sum := digest.Sum(nil) return sum, nil } func PartReadInt64(part *multipart.Part, maxchars int, bitSize int) (int64, error) { value, err := PartReadAll(part, maxchars) if err != nil { return 0, errors.WithStack(err) } return strconv.ParseInt(value, 10, bitSize) } func PartReadUInt64(part *multipart.Part, maxchars int, bitSize int) (uint64, error) { value, err := PartReadAll(part, maxchars) if err != nil { return 0, errors.WithStack(err) } return strconv.ParseUint(value, 10, bitSize) } func PartReadInt(part *multipart.Part, maxchars int) (int, error) { value, err := PartReadInt64(part, maxchars, 32) if err != nil { return 0, errors.WithStack(err) } return int(value), nil }