view ServerMonitor/Objects/ServerMonitor.cs @ 5:b6fe203af9d5

Private key passwords and validation
author Brad Greco <brad@bgreco.net>
date Thu, 28 Feb 2019 21:19:32 -0500
parents 3142e52cbe69
children c1dffaac66fa
line wrap: on
line source

using Renci.SshNet;
using Renci.SshNet.Common;
using ServerMonitorApp.Properties;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net.NetworkInformation;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Windows.Forms;
using System.Xml.Serialization;

namespace ServerMonitorApp
{
    /// <summary>
    /// Keeps the list of monitored servers, saves and loads it as XML in the user's
    /// application data folder, and schedules and runs the servers' checks.
    /// </summary>
    public class ServerMonitor
    {
        private readonly string configFileDir;
        private readonly Logger logger;
        // Cancellation source of the most recently scheduled run of each check, keyed by check ID.
        private readonly Dictionary<int, CancellationTokenSource> tokens = new Dictionary<int, CancellationTokenSource>();
        // Private keys by file path. A null value means the key could not be opened (see LockedKeys).
        private readonly Dictionary<string, PrivateKeyFile> privateKeys = new Dictionary<string, PrivateKeyFile>();
        // IDs of checks whose scheduled run was skipped because the network was unavailable.
        private readonly List<int> pausedChecks = new List<int>();
        private readonly ServerSummaryForm mainForm;
        private bool running, networkAvailable;
        // Pending scheduled runs, mapped to the ID of the check each one runs. Starts out empty
        // (rather than null) so event handlers that fire before Run() has scheduled anything are safe.
        private Dictionary<Task<CheckResult>, int> tasks = new Dictionary<Task<CheckResult>, int>();
        // Completed when a run is scheduled from outside the Run() loop, so the loop stops waiting
        // on its old set of tasks and starts waiting on the new one as well.
        private TaskCompletionSource<bool> tasksChanged = new TaskCompletionSource<bool>();
        // Created once and reused: the XmlSerializer constructor that takes extra types generates
        // a new assembly on every call and never unloads it.
        private XmlSerializer xmlSerializer;

        /// <summary>
        /// Raised when a check starts running (with a null result) and again when it finishes.
        /// The sender is the check itself.
        /// </summary>
        public event EventHandler<CheckStatusChangedEventArgs> CheckStatusChanged;

        /// <summary>The monitored servers.</summary>
        public List<Server> Servers { get; private set; } = new List<Server>();

        /// <summary>All checks of all monitored servers.</summary>
        public IEnumerable<Check> Checks => Servers.SelectMany(s => s.Checks);

        /// <summary>Full path of the XML file the servers are saved to.</summary>
        public string ConfigFile { get; private set; }

        /// <summary>Paths of private key files that were requested but could not be opened.</summary>
        public IEnumerable<string> LockedKeys { get { return privateKeys.Where(kvp => kvp.Value == null).Select(kvp => kvp.Key); } }

        /// <summary>
        /// Creates a monitor that stores its config file and log in the "ServerMonitor" folder
        /// under the user's application data folder.
        /// </summary>
        /// <param name="mainForm">Form used for alerts, balloons and UI-thread marshaling.</param>
        public ServerMonitor(ServerSummaryForm mainForm)
        {
            this.mainForm = mainForm;
            configFileDir = Path.Combine(Environment.GetFolderPath(Environment.SpecialFolder.ApplicationData), "ServerMonitor");
            ConfigFile = Path.Combine(configFileDir, "servers.xml");
            logger = new Logger(Path.Combine(configFileDir, "monitor.log"));
        }

        /// <summary>Adds a server to the monitored list and saves the list.</summary>
        public void AddServer(Server server)
        {
            // Listen for check and enabled-state changes so the server's checks get scheduled,
            // as LoadServers does for servers read from the config file. Unsubscribing first
            // keeps this idempotent.
            server.CheckModified -= Server_CheckModified;
            server.CheckModified += Server_CheckModified;
            server.EnabledChanged -= Server_EnabledChanged;
            server.EnabledChanged += Server_EnabledChanged;
            Servers.Add(server);
            SaveServers();
        }

        /// <summary>Removes a server from the monitored list, stops its checks and saves the list.</summary>
        public void DeleteServer(Server server)
        {
            server.CheckModified -= Server_CheckModified;
            server.EnabledChanged -= Server_EnabledChanged;
            // Stop pending runs of the deleted server's checks so they don't keep executing.
            foreach (Check check in server.Checks.ToList())
                CancelCheck(check);
            Servers.Remove(server);
            SaveServers();
        }

        /// <summary>
        /// Replaces <see cref="Servers"/> with the contents of the config file (if it exists),
        /// opens the servers' private keys, and starts the check scheduler.
        /// </summary>
        /// <exception cref="InvalidOperationException">The config file could not be deserialized.</exception>
        public void LoadServers()
        {
            // Deserialize before touching Servers so a broken file leaves the current list intact.
            List<Server> loadedServers = ReadConfigFile();
            if (loadedServers != null)
            {
                Servers.Clear();
                Servers.AddRange(loadedServers);
                // Update the Checks so they know what Server they belong to.
                // Would rather do this in the Server object on deserialization, but
                // that doesn't work when using the XML serializer for some reason.
                foreach (Server server in Servers)
                {
                    if (server.LoginType == LoginType.PrivateKey)
                        OpenPrivateKey(server.KeyFile);
                    foreach (Check check in server.Checks)
                    {
                        check.Server = server;
                        // Running is saved when a check starts; if it is still in the file, that
                        // run was interrupted, so fall back to the status of the last finished run.
                        if (check.Status == CheckStatus.Running)
                            check.Status = check.LastRunStatus;
                    }
                    server.CheckModified += Server_CheckModified;
                    server.EnabledChanged += Server_EnabledChanged;
                }
            }
            // Unsubscribe first so calling LoadServers more than once does not stack handlers
            // on this static event.
            NetworkChange.NetworkAddressChanged -= NetworkChange_NetworkAddressChanged;
            NetworkChange.NetworkAddressChanged += NetworkChange_NetworkAddressChanged;
            Run();
        }

        /// <summary>Reads the server list from the config file; returns null if the file does not exist yet.</summary>
        private List<Server> ReadConfigFile()
        {
            try
            {
                using (TextReader reader = new StreamReader(ConfigFile))
                    return (List<Server>)GetXmlSerializer().Deserialize(reader);
            }
            // If the file doesn't exist, no special handling is needed. It will be created later.
            catch (FileNotFoundException) { return null; }
            catch (DirectoryNotFoundException) { return null; }
        }

        /// <summary>
        /// Assigns IDs to new servers and checks, then writes <see cref="Servers"/> to the config
        /// file, creating its directory if needed.
        /// </summary>
        public void SaveServers()
        {
            GenerateIds();
            // Serialize into memory first: if serialization fails, the existing config file is
            // left intact instead of being truncated to a partial document.
            byte[] xml;
            using (MemoryStream stream = new MemoryStream())
            using (StreamWriter writer = new StreamWriter(stream))
            {
                GetXmlSerializer().Serialize(writer, Servers);
                writer.Flush();
                xml = stream.ToArray();
            }
            Directory.CreateDirectory(configFileDir);
            File.WriteAllBytes(ConfigFile, xml);
        }

        /// <summary>
        /// Schedules every check and processes completed runs (rescheduling them) until no
        /// check remains scheduled. Returns immediately if the loop is already running.
        /// </summary>
        /// <remarks>async void because it is started fire-and-forget from event handlers.</remarks>
        private async void Run()
        {
            if (running)
                return;
            running = true;
            networkAvailable = Helpers.IsNetworkAvailable();
            if (networkAvailable)
                await RunPausedChecksAsync();
            //TODO subscribe to power events. Find any check's NextExecutionTime is in the past. Cancel waiting task and run immediately (or after short delay).
            // Schedule every check that is not already scheduled. A check can already have a task
            // if it was modified while the paused checks above were running.
            HashSet<int> scheduledIds = new HashSet<int>(tasks.Values);
            foreach (Check check in Checks.Where(c => !scheduledIds.Contains(c.Id)).ToList())
                tasks.Add(ScheduleExecuteCheckAsync(check), check.Id);
            while (tasks.Count > 0)
            {
                // Re-arm the wake-up signal once it has fired so the next wait is not cut short.
                if (tasksChanged.Task.IsCompleted)
                    tasksChanged = new TaskCompletionSource<bool>();
                Task wakeUp = tasksChanged.Task;
                Task completed = await Task.WhenAny(tasks.Keys.Cast<Task>().Concat(new[] { wakeUp }));
                // A task was added from outside the loop: wait again, this time including it.
                if (completed == wakeUp)
                    continue;
                Task<CheckResult> task = (Task<CheckResult>)completed;
                // A task that is no longer in the dictionary was cancelled by CancelCheck (check
                // modified or deleted, or server disabled). Any replacement is already scheduled,
                // so rescheduling from this result would run the check twice.
                if (!tasks.Remove(task))
                    continue;
                try
                {
                    CheckResult result = await task;
                    // Don't reschedule disabled checks (or runs that produced no result).
                    if (result != null && result.CheckStatus != CheckStatus.Disabled)
                        tasks.Add(ScheduleExecuteCheckAsync(result.Check), result.Check.Id);
                }
                catch (OperationCanceledException)
                {
                    // The run was cancelled; it is rescheduled elsewhere if still needed.
                }
            }
            running = false;
        }

        /// <summary>Runs, one at a time, the checks whose scheduled run was skipped while the network was down.</summary>
        private async Task RunPausedChecksAsync()
        {
            // Work from a snapshot: the list can change while a check is awaited (CancelCheck
            // removes IDs, ScheduleExecuteCheckAsync adds them), which would break a foreach.
            List<int> pausedIds = pausedChecks.ToList();
            pausedChecks.Clear();
            foreach (int id in pausedIds)
            {
                // The check may have been deleted or disabled since it was paused.
                Check check = Checks.FirstOrDefault(c => c.Id == id);
                if (check != null && check.Enabled && check.Server.Enabled)
                    await ExecuteCheckAsync(check);
            }
        }

        /// <summary>
        /// Runs a check now. Raises <see cref="CheckStatusChanged"/> when it starts and when it
        /// finishes, then logs the result and applies its FailAction.
        /// </summary>
        /// <param name="check">Check to run.</param>
        /// <param name="token">Token passed on to the check's own execution.</param>
        /// <returns>The result of the run.</returns>
        public async Task<CheckResult> ExecuteCheckAsync(Check check, CancellationToken token = default(CancellationToken))
        {
            check.Status = CheckStatus.Running;
            OnCheckStatusChanged(check);
            CheckResult result = await check.ExecuteAsync(token);
            OnCheckStatusChanged(check, result);
            HandleResult(result);
            return result;
        }

        /// <summary>Logs a check result and alerts the user as requested by the result's FailAction.</summary>
        private void HandleResult(CheckResult result)
        {
            logger.Log(result);
            if (result.FailAction == FailAction.FlashTaskbar)
                mainForm.AlertServerForm(result.Check);
            if (result.FailAction.In(FailAction.FlashTaskbar, FailAction.NotificationBalloon))
                mainForm.ShowBalloon(result);
        }

        /// <summary>Reads the logged check results for a server.</summary>
        public IList<CheckResult> GetLog(Server server)
        {
            return logger.Read(server);
        }

        /// <summary>Saves the servers so the new status is persisted, then raises <see cref="CheckStatusChanged"/>.</summary>
        private void OnCheckStatusChanged(Check check, CheckResult result = null)
        {
            SaveServers();
            CheckStatusChanged?.Invoke(check, new CheckStatusChangedEventArgs(check, result));
        }

        /// <summary>
        /// Reschedules a check after it was added, modified or deleted (a deleted check has no Server).
        /// </summary>
        private void Server_CheckModified(object sender, EventArgs e)
        {
            Check check = (Check)sender;
            if (running)
            {
                // Drop any pending run of the old version of the check.
                if (tasks.ContainsValue(check.Id))
                    CancelCheck(check);
                // Schedule the check again unless it was deleted. Re-test running: cancelling can
                // let the Run() loop finish synchronously, and then the Run() call below
                // reschedules everything, so adding a task here would create a duplicate run.
                if (check.Server != null && running)
                    AddTask(check);
            }
            // Run again in case removing a task above caused it to stop
            Run();
        }

        /// <summary>Schedules or cancels a server's checks when the server is enabled or disabled.</summary>
        private void Server_EnabledChanged(object sender, EventArgs e)
        {
            Server server = (Server)sender;
            if (server.Enabled)
            {
                // Run() returns immediately when the loop is already running, so in that case
                // schedule the re-enabled server's checks directly.
                if (running)
                {
                    foreach (Check check in server.Checks.ToList())
                    {
                        if (!tasks.ContainsValue(check.Id))
                            AddTask(check);
                    }
                }
                Run();
            }
            else
            {
                foreach (Check check in server.Checks)
                {
                    CancelCheck(check);
                }
            }
        }

        /// <summary>Schedules a check from outside the Run() loop and wakes the loop so it waits on the new task.</summary>
        private void AddTask(Check check)
        {
            tasks.Add(ScheduleExecuteCheckAsync(check), check.Id);
            tasksChanged.TrySetResult(true);
        }

        /// <summary>Removes a check's pending run and paused state and cancels its current token.</summary>
        private void CancelCheck(Check check)
        {
            Task<CheckResult> task = tasks.FirstOrDefault(kvp => kvp.Value == check.Id).Key;
            if (task != null)
                tasks.Remove(task);
            pausedChecks.RemoveAll(id => id == check.Id);
            if (tokens.TryGetValue(check.Id, out CancellationTokenSource cts))
                cts.Cancel();
        }

        /// <summary>
        /// Waits until the check's next scheduled time, then runs it if the network is available.
        /// Otherwise it marks the check as paused so it runs when the network returns.
        /// </summary>
        /// <returns>
        /// The run's result, or a Disabled result if the check or its server is disabled or the
        /// run was skipped because the network is down.
        /// </returns>
        private async Task<CheckResult> ScheduleExecuteCheckAsync(Check check)
        {
            if (!check.Enabled || !check.Server.Enabled)
                return new CheckResult(check, CheckStatus.Disabled, null);

            CancellationTokenSource cts = new CancellationTokenSource();
            tokens[check.Id] = cts;
            check.NextRunTime = check.Schedule.GetNextTime(check.LastScheduledRunTime);
            await DelayUntilAsync(check.NextRunTime, cts.Token);
            check.LastScheduledRunTime = check.NextRunTime;
            if (networkAvailable)
                return await ExecuteCheckAsync(check, cts.Token);

            if (!pausedChecks.Contains(check.Id))
                pausedChecks.Add(check.Id);
            return new CheckResult(check, CheckStatus.Disabled, null);
        }

        /// <summary>
        /// Waits until the given local time. Times in the past complete immediately, and long
        /// waits are split into several delays.
        /// </summary>
        private static async Task DelayUntilAsync(DateTime time, CancellationToken token)
        {
            TimeSpan remaining = time - DateTime.Now;
            while (remaining > TimeSpan.Zero)
            {
                // Task.Delay rejects negative delays (except -1 ms, which means "forever") and
                // delays above int.MaxValue ms (about 24.8 days), so wait in bounded steps and
                // recompute the remaining time after each one.
                int milliseconds = (int)Math.Min(Math.Ceiling(remaining.TotalMilliseconds), int.MaxValue);
                await Task.Delay(milliseconds, token);
                remaining = time - DateTime.Now;
            }
        }

        /// <summary>Re-checks network availability and restarts the scheduler when the network is back.</summary>
        private void NetworkChange_NetworkAddressChanged(object sender, EventArgs e)
        {
            networkAvailable = Helpers.IsNetworkAvailable();
            if (!networkAvailable)
                return;
            try
            {
                // Marshal to the form's thread, where the rest of the scheduling state is used;
                // this handler may be called on another thread.
                mainForm.Invoke((MethodInvoker)(() => Run()));
            }
            catch (Exception ex) when (ex is ObjectDisposedException || ex is InvalidOperationException)
            {
                // The main form is closed or has no window handle (e.g. during shutdown);
                // there is nothing to resume. Letting this escape would crash the process.
            }
        }

        /// <summary>
        /// Opens a private key file and caches the outcome. Every server using that key file
        /// gets the key and status assigned. If the key is already open, returns Open without
        /// touching the servers.
        /// </summary>
        /// <param name="path">Path of the key file; null yields NotAccessible.</param>
        /// <param name="password">Passphrase for an encrypted key, or null.</param>
        /// <returns>
        /// Open on success. NeedPassword if SSH.NET reports a missing passphrase or throws
        /// InvalidOperationException. NotAccessible for any other failure.
        /// </returns>
        public KeyStatus OpenPrivateKey(string path, string password = null)
        {
            KeyStatus keyStatus;
            if (path == null)
                return KeyStatus.NotAccessible;
            if (privateKeys.TryGetValue(path, out PrivateKeyFile key) && key != null)
                return KeyStatus.Open;
            try
            {
                key = new PrivateKeyFile(path, password);
                keyStatus = KeyStatus.Open;
            }
            catch (Exception e) when (e is SshPassPhraseNullOrEmptyException || e is InvalidOperationException)
            {
                keyStatus = KeyStatus.NeedPassword;
            }
            catch (Exception)
            {
                keyStatus = KeyStatus.NotAccessible;
            }
            foreach (Server server in Servers)
            {
                if (server.KeyFile == path)
                {
                    server.PrivateKeyFile = key;
                    server.KeyStatus = keyStatus;
                }
            }
            // key is null here if opening failed, which is what LockedKeys reports.
            privateKeys[path] = key;
            return keyStatus;
        }

        /// <summary>
        /// Gives every server and check whose Id is 0 a new ID above the current maximum.
        /// The highest check ID is persisted in Settings.Default.MaxCheckId, presumably so IDs
        /// of deleted checks are never reused.
        /// </summary>
        private void GenerateIds()
        {
            if (Servers.Any())
            {
                int id = Servers.Max(s => s.Id);
                foreach (Server server in Servers)
                {
                    if (server.Id == 0)
                        server.Id = ++id;
                }
            }

            if (Checks.Any())
            {
                int id = Math.Max(Settings.Default.MaxCheckId, Checks.Max(c => c.Id));
                foreach (Check check in Checks)
                {
                    if (check.Id == 0)
                        check.Id = ++id;
                }
                Settings.Default.MaxCheckId = id;
                Settings.Default.Save();
            }
        }

        /// <summary>Returns the shared serializer for the server list, including all check types.</summary>
        private XmlSerializer GetXmlSerializer()
        {
            if (xmlSerializer == null)
                xmlSerializer = new XmlSerializer(typeof(List<Server>), Check.CheckTypes);
            return xmlSerializer;
        }
    }

    /// <summary>
    /// Event data describing a check whose status changed. When the change marks the end
    /// of a run, the data also carries that run's result; when the check has just started
    /// running, the result is null.
    /// </summary>
    public class CheckStatusChangedEventArgs : EventArgs
    {
        /// <summary>Builds the event data.</summary>
        /// <param name="check">The check whose status changed.</param>
        /// <param name="result">Result of the finished run, or null.</param>
        public CheckStatusChangedEventArgs(Check check, CheckResult result)
        {
            this.Check = check;
            this.CheckResult = result;
        }

        /// <summary>The check whose status changed.</summary>
        public Check Check { get; private set; }

        /// <summary>Result of the run that produced the change, or null.</summary>
        public CheckResult CheckResult { get; private set; }
    }

    /// <summary>How the user is alerted about a check result.</summary>
    public enum FailAction
    {
        /// <summary>Alert the server's window and also show a notification balloon.</summary>
        FlashTaskbar = 0,

        /// <summary>Show a notification balloon only.</summary>
        NotificationBalloon = 1,

        /// <summary>Do not alert the user.</summary>
        None = 10,
    }
}