|
@@ -523,6 +523,13 @@ func tryDownload(inputURL string, filename string, path string, message *discord
|
|
|
}
|
|
|
downloadTime := time.Now()
|
|
|
|
|
|
+ // Read
|
|
|
+ bodyOfResp, err := ioutil.ReadAll(response.Body)
|
|
|
+ if err != nil {
|
|
|
+ log.Println(logPrefixErrorHere, color.HiRedString("Could not read response from \"%s\": %s", inputURL, err))
|
|
|
+ return mDownloadStatus(downloadFailedReadResponse, err)
|
|
|
+ }
|
|
|
+
|
|
|
// Filename
|
|
|
if filename == "" {
|
|
|
filename = filenameFromURL(response.Request.URL.String())
|
|
@@ -542,19 +549,18 @@ func tryDownload(inputURL string, filename string, path string, message *discord
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // Read
|
|
|
- bodyOfResp, err := ioutil.ReadAll(response.Body)
|
|
|
- if err != nil {
|
|
|
- log.Println(logPrefixErrorHere, color.HiRedString("Could not read response from \"%s\": %s", inputURL, err))
|
|
|
- return mDownloadStatus(downloadFailedReadResponse, err)
|
|
|
- }
|
|
|
+ extension := strings.ToLower(filepath.Ext(filename))
|
|
|
|
|
|
contentType := http.DetectContentType(bodyOfResp)
|
|
|
contentTypeParts := strings.Split(contentType, "/")
|
|
|
contentTypeFound := contentTypeParts[0]
|
|
|
|
|
|
+ parsedURL, err := url.Parse(inputURL)
|
|
|
+ if err != nil {
|
|
|
+ log.Println(logPrefixErrorHere, color.RedString("Error while parsing url:\t%s", err))
|
|
|
+ }
|
|
|
+
|
|
|
// Check extension
|
|
|
- extension := strings.ToLower(filepath.Ext(filename))
|
|
|
if stringInSlice(extension, *channelConfig.ExtensionBlacklist) || stringInSlice(extension, []string{".com", ".net", ".org"}) {
|
|
|
if !historyCmd {
|
|
|
log.Println(logPrefixFileSkip, color.GreenString("Unpermitted extension (%s) found at %s", extension, inputURL))
|
|
@@ -593,14 +599,13 @@ func tryDownload(inputURL string, filename string, path string, message *discord
|
|
|
|
|
|
// Check Domain
|
|
|
if channelConfig.DomainBlacklist != nil {
|
|
|
- u, err := url.Parse(inputURL)
|
|
|
- if err != nil {
|
|
|
- log.Println(logPrefixErrorHere, color.RedString("Error while parsing url for DomainBlacklist:\t%s", err))
|
|
|
- } else if stringInSlice(u.Hostname(), *channelConfig.DomainBlacklist) {
|
|
|
- if !historyCmd {
|
|
|
- log.Println(logPrefixFileSkip, color.GreenString("Unpermitted domain (%s) found at %s", u.Hostname(), inputURL))
|
|
|
+ if parsedURL != nil {
|
|
|
+ if stringInSlice(parsedURL.Hostname(), *channelConfig.DomainBlacklist) {
|
|
|
+ if !historyCmd {
|
|
|
+ log.Println(logPrefixFileSkip, color.GreenString("Unpermitted domain (%s) found at %s", parsedURL.Hostname(), inputURL))
|
|
|
+ }
|
|
|
+ return mDownloadStatus(downloadSkippedUnpermittedDomain)
|
|
|
}
|
|
|
- return mDownloadStatus(downloadSkippedUnpermittedDomain)
|
|
|
}
|
|
|
}
|
|
|
|