diff --git a/fs.go b/fs.go index 77dce66..b5733cf 100644 --- a/fs.go +++ b/fs.go @@ -51,18 +51,24 @@ func (fs *FS) ServeHTTP(w http.ResponseWriter, r *http.Request, info RoutingInfo } // search for file in static file system - file, err := fs.findFile(url) + file, filePath, err := fs.findFile(url) if err != nil { http.NotFound(w, r) return } + // check, if the file is a folder + file, ok := fs.checkFolder(file, filePath) + if !ok { + http.NotFound(w, r) + return + } defer file.Close() // serve file content fs.serve(w, r, info, file) } -func (fs *FS) findFile(url string) (fs.File, error) { +func (fs *FS) findFile(url string) (fs.File, string, error) { // build file path filePath := path.Join(fs.FolderPrefix, url) filePath = strings.TrimPrefix(filePath, "/") @@ -71,15 +77,40 @@ func (fs *FS) findFile(url string) (fs.File, error) { if fs.UseLocalFolder { localFilePath := path.Join(fs.LocalFolderPrefix, filePath) localFilePath = strings.TrimPrefix(localFilePath, "/") + if localFilePath == "" { + localFilePath = "." + } file, err := os.Open(localFilePath) if err == nil { - return file, nil + return file, localFilePath, nil } } // use static file system - return fs.StaticFiles.Open(filePath) + if filePath == "" { + filePath = "." + } + file, err := fs.StaticFiles.Open(filePath) + return file, filePath, err +} + +func (fs *FS) checkFolder(file fs.File, url string) (fs.File, bool) { + fileInfo, err := file.Stat() + if err != nil { + _ = file.Close() + return file, false + } + + // check for folder / root location + if fileInfo.IsDir() { + // close folder handler (we don't need it anymore) + _ = file.Close() + + return file, false + } + + return file, true } func (fs *FS) serve(w http.ResponseWriter, r *http.Request, info RoutingInfo, file fs.File) { diff --git a/fs_test.go b/fs_test.go index 5b74c64..049acc8 100644 --- a/fs_test.go +++ b/fs_test.go @@ -114,6 +114,48 @@ func TestFSLicense(t *testing.T) { } } +func TestFSFolder(t *testing.T) { + // create new FS + fs := NewFS(&staticFiles) + + // serve the request + tw := &testWriter{} + tr := &http.Request{ + Method: http.MethodHead, + URL: &url.URL{ + Path: "/", + }, + } + fs.ServeHTTP(tw, tr, RoutingInfo{}) + + // check data + if tw.statusCode != http.StatusNotFound { + t.Error("received invalid http status code", tw.statusCode) + } +} + +func TestFSFolderLocal(t *testing.T) { + // create new FS + fs := NewFS(&staticFiles) + fs.UseLocalFolder = true + fs.LocalFolderPrefix = "/" + + // serve the request + tw := &testWriter{} + tr := &http.Request{ + Method: http.MethodHead, + URL: &url.URL{ + Path: "/", + }, + } + fs.ServeHTTP(tw, tr, RoutingInfo{}) + + // check data + if tw.statusCode != http.StatusNotFound { + t.Error("received invalid http status code", tw.statusCode) + } +} + func TestFSNotFound(t *testing.T) { // create new FS fs := NewFS(&staticFiles)