160 lines
3.5 KiB
Go
160 lines
3.5 KiB
Go
package utils
|
|
|
|
import (
|
|
"crypto/sha256"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"strconv"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// chunkSize is the size, in bytes, of the scratch buffer used when reading
// multipart data; ChunkFilePart also hands data to its encryptor in pieces of
// this size (only the last piece may be shorter).
const chunkSize int = 1 << 12
|
|
|
|
// FileInfo used to send part filename and content type
|
|
type FileInfo struct {
|
|
FileID string
|
|
Filename string
|
|
ContentType string
|
|
Part *multipart.Part
|
|
FileSize uint64
|
|
OrigFileSize uint64
|
|
}
|
|
|
|
// FormParser type func that handles parsing of form fields values
|
|
type FormParser func(part *multipart.Part) error
|
|
type FileEncryptor func(input []byte) []byte
|
|
|
|
// FindFilePart scans the multipart body of r, delimited by boundary, for the
// first part whose form name is fileFieldName and that carries a non-empty
// filename.
//
// Every other part encountered before it is passed to formParser (when
// non-nil) and then closed. If formParser returns an error, the scan stops and
// that error is returned as-is.
//
// When the file part is found, the returned FileInfo has Filename, ContentType,
// OrigFileSize and Part set. OrigFileSize is taken from the part's
// Content-Length header and is 0 when that header is absent or not a valid
// unsigned integer. Part reads directly from r.Body, so it must be consumed
// before the body is used for anything else.
//
// If the body ends without a matching part, an empty FileInfo (Part == nil)
// and a nil error are returned. Errors from the multipart reader are returned
// wrapped with a stack trace.
func FindFilePart(r *http.Request, boundary string, fileFieldName string, formParser FormParser) (*FileInfo, error) {
	result := &FileInfo{}
	mr := multipart.NewReader(r.Body, boundary)

	for {
		part, err := mr.NextPart()
		if errors.Is(err, io.EOF) {
			return result, nil
		}
		if err != nil {
			return result, errors.WithStack(err)
		}

		if part.FormName() != fileFieldName || part.FileName() == "" {
			if formParser != nil {
				if err := formParser(part); err != nil {
					return result, err
				}
			}
			part.Close()
			continue
		}

		result.Filename = part.FileName()
		result.ContentType = GetHTTPHeader(part.Header, "Content-Type", "")
		// Parse as a full 64-bit unsigned size: a 32-bit parse silently clamps
		// anything over 2 GiB (its range error was ignored), and a signed parse
		// lets negative values wrap to enormous uint64 sizes. Unparsable values
		// leave OrigFileSize at 0.
		if size, perr := strconv.ParseUint(GetHTTPHeader(part.Header, "Content-Length", "0"), 10, 64); perr == nil {
			result.OrigFileSize = size
		}
		result.Part = part

		return result, nil
	}
}
|
|
|
|
// ChunkFilePart chunks file part into pipe writer
|
|
func ChunkFilePart(writer *io.PipeWriter, info *FileInfo, encryptor FileEncryptor) {
|
|
var err error
|
|
var origFileSize uint64
|
|
var fileSize uint64
|
|
defer writer.Close()
|
|
|
|
n := 0
|
|
buf := make([]byte, chunkSize)
|
|
|
|
for {
|
|
n, err = io.ReadFull(info.Part, buf)
|
|
if n > 0 {
|
|
if encryptor != nil {
|
|
out := encryptor(buf[:n])
|
|
writer.Write(out)
|
|
fileSize += uint64(len(out))
|
|
} else {
|
|
writer.Write(buf[:n])
|
|
fileSize += uint64(n)
|
|
}
|
|
origFileSize += uint64(n)
|
|
}
|
|
if errors.Is(err, io.EOF) {
|
|
if info != nil {
|
|
info.FileSize = fileSize
|
|
info.OrigFileSize = origFileSize
|
|
}
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// PartReadAll reads part to the end and returns its content as a string.
//
// maxchars caps the accepted size in bytes (not runes). Once more than
// maxchars bytes have been read, an error naming the form field is returned.
// Read errors other than io.EOF (for example a truncated request body) are
// returned wrapped with a stack trace instead of being retried.
func PartReadAll(part *multipart.Part, maxchars int) (string, error) {
	buf := make([]byte, chunkSize)
	var result []byte

	for {
		n, err := part.Read(buf)
		result = append(result, buf[:n]...)
		// Checked after every read, so an oversized value is rejected after at
		// most one chunk beyond the limit, without buffering the whole part.
		if len(result) > maxchars {
			return "", errors.Errorf("%s exceeded %d characters", part.FormName(), maxchars)
		}
		if errors.Is(err, io.EOF) {
			return string(result), nil
		}
		if err != nil {
			// Without this the loop would spin forever: the part reader keeps
			// returning the same error (e.g. io.ErrUnexpectedEOF) with no data.
			return "", errors.WithStack(err)
		}
	}
}
|
|
|
|
func HashMultipartFile(file multipart.File) ([]byte, error) {
|
|
var err error
|
|
n := 0
|
|
buf := make([]byte, chunkSize)
|
|
digest := sha256.New()
|
|
|
|
for {
|
|
n, err = file.Read(buf)
|
|
if n > 0 {
|
|
_, err := digest.Write(buf[:n])
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
}
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
}
|
|
sum := digest.Sum(nil)
|
|
return sum, nil
|
|
}
|
|
|
|
// PartReadInt64 reads the whole part through PartReadAll (at most maxchars
// bytes) and parses it as a signed base-10 integer that must fit in bitSize
// bits. Read errors are wrapped with a stack trace; strconv parse errors are
// returned unwrapped.
func PartReadInt64(part *multipart.Part, maxchars int, bitSize int) (int64, error) {
	text, readErr := PartReadAll(part, maxchars)
	if readErr == nil {
		return strconv.ParseInt(text, 10, bitSize)
	}
	return 0, errors.WithStack(readErr)
}
|
|
|
|
// PartReadUInt64 reads the whole part through PartReadAll (at most maxchars
// bytes) and parses it as an unsigned base-10 integer that must fit in bitSize
// bits. Read errors are wrapped with a stack trace; strconv parse errors are
// returned unwrapped.
func PartReadUInt64(part *multipart.Part, maxchars int, bitSize int) (uint64, error) {
	text, readErr := PartReadAll(part, maxchars)
	if readErr == nil {
		return strconv.ParseUint(text, 10, bitSize)
	}
	return 0, errors.WithStack(readErr)
}
|
|
|
|
// PartReadInt reads the part as a base-10 integer limited to the 32-bit signed
// range (via PartReadInt64) and returns it as an int. Any error is wrapped with
// a stack trace and the value is then 0.
func PartReadInt(part *multipart.Part, maxchars int) (int, error) {
	v, parseErr := PartReadInt64(part, maxchars, 32)
	if parseErr == nil {
		return int(v), nil
	}
	return 0, errors.WithStack(parseErr)
}
|