Files
cloud-services/pkg/utils/multipart_parser.go

160 lines
3.5 KiB
Go

package utils
import (
"crypto/sha256"
"io"
"mime/multipart"
"net/http"
"strconv"
"github.com/pkg/errors"
)
// chunkSize is the length in bytes of the scratch buffers the helpers in this
// file allocate when reading a part or file piece by piece.
const chunkSize int = 4096
// FileInfo describes the file part located by FindFilePart and carries the
// sizes recorded by ChunkFilePart once the part has been streamed.
type FileInfo struct {
// FileID is not assigned anywhere in this file; presumably the caller fills
// it in — TODO confirm.
FileID string
// Filename is the filename reported by the part (part.FileName()).
Filename string
// ContentType is the part's Content-Type header, read through GetHTTPHeader
// with "" passed as the fallback argument.
ContentType string
// Part is the multipart part that holds the file content.
Part *multipart.Part
// FileSize is set by ChunkFilePart to the number of bytes handed to the pipe
// writer, i.e. counted after the encryptor (if any) has been applied.
FileSize uint64
// OrigFileSize is first set by FindFilePart from the part's Content-Length
// header and later overwritten by ChunkFilePart with the number of bytes
// actually read from the part.
OrigFileSize uint64
}
// FormParser is the callback FindFilePart invokes for every part that is not
// the requested file part, so the caller can consume ordinary form fields.
// Returning a non-nil error makes FindFilePart stop and return that error.
type FormParser func(part *multipart.Part) error
// FileEncryptor transforms one chunk of file content before it is written out.
// ChunkFilePart calls it once per chunk of at most chunkSize bytes and writes
// the returned slice, whose length is what gets counted towards FileSize.
// NOTE(review): presumably the returned bytes are ciphertext — the actual
// transformation is defined by the caller.
type FileEncryptor func(input []byte) []byte
// FindFilePart scans the multipart body of r for the first part whose form
// name equals fileFieldName and which carries a non-empty filename, and
// returns its metadata together with the still-unread part.
//
// Every part seen before the file part is passed to formParser (when
// formParser is non-nil) and then closed. Scanning stops at the file part, so
// parts that follow it are never handed to formParser.
//
// The returned *FileInfo is never nil:
//   - when the body ends without a matching part, Part is nil and the error
//     is nil;
//   - when reading the next part fails, the error is returned with a stack;
//   - when formParser fails, its error is returned unchanged.
//
// OrigFileSize is taken from the part's own Content-Length header and is 0
// when that header is missing or is not a valid unsigned integer.
func FindFilePart(r *http.Request, boundary string, fileFieldName string, formParser FormParser) (*FileInfo, error) {
	result := &FileInfo{}
	mr := multipart.NewReader(r.Body, boundary)
	for {
		part, err := mr.NextPart()
		if errors.Is(err, io.EOF) {
			return result, nil
		}
		if err != nil {
			return result, errors.WithStack(err)
		}

		if part.FormName() == fileFieldName && part.FileName() != "" {
			result.Filename = part.FileName()
			result.ContentType = GetHTTPHeader(part.Header, "Content-Type", "")
			// Parse as an unsigned 64-bit value: a 32-bit signed parse clamps
			// anything from 2 GiB upwards and lets a negative header wrap
			// around when converted to uint64.
			size, sizeErr := strconv.ParseUint(GetHTTPHeader(part.Header, "Content-Length", "0"), 10, 64)
			if sizeErr != nil {
				// The header is only a hint; an unusable value means "unknown".
				size = 0
			}
			result.OrigFileSize = size
			result.Part = part
			return result, nil
		}

		if formParser != nil {
			if err := formParser(part); err != nil {
				return result, err
			}
		}
		// Best effort: Close only discards whatever formParser left unread.
		part.Close()
	}
}
// ChunkFilePart streams info.Part into writer in chunks of up to chunkSize
// bytes. When encryptor is non-nil each chunk is passed through it and the
// returned bytes are written instead of the raw chunk.
//
// The writer is always closed before the function returns:
//   - on success the read side of the pipe observes io.EOF, and info.FileSize
//     (bytes written) and info.OrigFileSize (bytes read) are updated first;
//   - when info or info.Part is nil, when reading the part fails, or when the
//     pipe rejects a write, the writer is closed with that error so the read
//     side does not mistake a partial stream for a complete file. The size
//     fields of info are left untouched in that case.
func ChunkFilePart(writer *io.PipeWriter, info *FileInfo, encryptor FileEncryptor) {
	var failure error
	// CloseWithError(nil) behaves like Close, so one deferred call covers both
	// the success and the failure path.
	defer func() { _ = writer.CloseWithError(failure) }()

	if info == nil || info.Part == nil {
		failure = errors.New("no file part to read")
		return
	}

	var origFileSize, fileSize uint64
	buf := make([]byte, chunkSize)
	for {
		n, err := io.ReadFull(info.Part, buf)
		if n > 0 {
			out := buf[:n]
			if encryptor != nil {
				out = encryptor(out)
			}
			if _, writeErr := writer.Write(out); writeErr != nil {
				// The read side is gone; reading further would be wasted work.
				failure = writeErr
				return
			}
			fileSize += uint64(len(out))
			origFileSize += uint64(n)
		}

		switch {
		case err == nil:
			// Full chunk delivered, keep going.
		case errors.Is(err, io.EOF):
			info.FileSize = fileSize
			info.OrigFileSize = origFileSize
			return
		case n > 0 && errors.Is(err, io.ErrUnexpectedEOF):
			// Either the short final chunk reported by io.ReadFull or a
			// truncated part. The next read tells them apart: io.EOF for the
			// former, the same error with no data for the latter.
		default:
			failure = err
			return
		}
	}
}
// PartReadAll reads part until io.EOF and returns its content as a string.
//
// maxchars limits the size of the value; the limit is applied to the number
// of bytes read, not to the number of runes. As soon as more than maxchars
// bytes have been accumulated an error naming the form field is returned.
// Any read error other than io.EOF is returned with a stack attached.
func PartReadAll(part *multipart.Part, maxchars int) (string, error) {
	buf := make([]byte, chunkSize)
	var result []byte
	for {
		n, err := part.Read(buf)
		result = append(result, buf[:n]...)
		if len(result) > maxchars {
			return "", errors.Errorf("%s exceeded %d characters", part.FormName(), maxchars)
		}
		if errors.Is(err, io.EOF) {
			return string(result), nil
		}
		if err != nil {
			// Without this a persistent read failure would loop forever.
			return "", errors.WithStack(err)
		}
	}
}
func HashMultipartFile(file multipart.File) ([]byte, error) {
var err error
n := 0
buf := make([]byte, chunkSize)
digest := sha256.New()
for {
n, err = file.Read(buf)
if n > 0 {
_, err := digest.Write(buf[:n])
if err != nil {
return nil, errors.WithStack(err)
}
}
if errors.Is(err, io.EOF) {
break
}
}
sum := digest.Sum(nil)
return sum, nil
}
// PartReadInt64 reads the whole part via PartReadAll (bounded by maxchars)
// and parses it as a base-10 signed integer that must fit in bitSize bits.
// A read failure is returned with a stack attached; a parse failure is
// returned exactly as strconv.ParseInt reports it.
func PartReadInt64(part *multipart.Part, maxchars int, bitSize int) (int64, error) {
	text, readErr := PartReadAll(part, maxchars)
	if readErr == nil {
		return strconv.ParseInt(text, 10, bitSize)
	}
	return 0, errors.WithStack(readErr)
}
// PartReadUInt64 reads the whole part via PartReadAll (bounded by maxchars)
// and parses it as a base-10 unsigned integer that must fit in bitSize bits.
// A read failure is returned with a stack attached; a parse failure is
// returned exactly as strconv.ParseUint reports it.
func PartReadUInt64(part *multipart.Part, maxchars int, bitSize int) (uint64, error) {
	text, readErr := PartReadAll(part, maxchars)
	if readErr == nil {
		return strconv.ParseUint(text, 10, bitSize)
	}
	return 0, errors.WithStack(readErr)
}
// PartReadInt reads the part as a base-10 integer limited to 32 bits (through
// PartReadInt64) and returns it as an int. Any failure from PartReadInt64 is
// returned with a stack attached.
func PartReadInt(part *multipart.Part, maxchars int) (int, error) {
	parsed, err := PartReadInt64(part, maxchars, 32)
	if err == nil {
		return int(parsed), nil
	}
	return 0, errors.WithStack(err)
}