#!/usr/bin/perl -w

# Copyright 2004 Victor Danilchenko <danilche@cs.umass.edu>
#
# This code may be distributed under the terms of GPL version 2,
# or at your option any subsequent version.

use lib "/exp/rcf/common";
use strict;
use Socket;
use Fcntl;
use Fcntl qw(:DEFAULT);
use Getopt::Long;
use IO::Socket::INET;
use POSIX ":errno_h";

# Forked distribution children are never wait()ed for; ignoring SIGCHLD lets
# the kernel reap them so they do not linger as zombies.
$SIG{CHLD} = 'IGNORE';

# Daemon identity. Also used to derive the PID file, the spool directory and
# the /etc/<name>-blacklist and /etc/<name>-clientlist state files.
my $name = "sshd_sentry_server";
my $pidfile = "/var/run/$name.pid";
# Holds one empty marker file per failed push to a client.
my $spooldir = "/var/spool/$name";
# NOTE(review): $hosts is not referenced anywhere else in the visible source.
my $hosts = {};
# Command-line option values, filled in by GetOptions.
my ($help, $restart, $threshold, $duration, $server_port, $interval, $verbose);

# Defaults for options not given on the command line. Duration and interval
# are human-readable specs that are converted to seconds at startup.
my $threshold_default = 2;
my $duration_default = "1 day";
my $server_port_default = 6094;
my $interval_default = "1 min";
my $verbose_default = 0;
# Runtime state:
#   $blacklist    - { blacklisted host => { reporting client IP => expiry epoch } }
#   $clients      - { client IP => callback port }
#   $last_distrib - epoch time of the last distribution
#   $children     - { pid of a forked distribution child => 1 }
#   %old_blist    - host => expiry, as of the last mailed blacklisting report
my $blacklist = {};
my $clients = {};
my $last_distrib = 0;
my $children = {};
my %old_blist = ();

# Local host name (first line of hostname(1) output) and its short form,
# used in SMTP greetings and mail subjects.
my $lhost = (`/bin/hostname`)[0]; chomp $lhost;
my $shost = (split(/\./, $lhost))[0];
# Mail settings for reports and death notices.
my $domain = 'cs.umass.edu';
my $mail_server = "mail.$domain";
my @sysmail = ("sysscript\@$domain"); 


sub help () {
    # Return the --help usage text. Option defaults are interpolated from the
    # *_default variables, so the text always matches the defaults in effect.
    my $filr = " " x length($name);
    return << "EOT";
Usage: $name [-h | --help]
       $filr [-r | --restart ]
       $filr [-t | --threshold <threshold number of failures> ]
       $filr [-d | --duration <duration to disable host for> ]
       $filr [-s | --server_port <port> ]
       $filr [-i | --interval <duration between distributions>]
       $filr [-v | --verbose <verbosity level> ]

help        Show this message
restart     Forcibly restart $name, killing the running instance if needed
threshold   Number of distinct clients that must report a host before it is
            blacklisted globally
            default: $threshold_default.
duration    Duration of time for which the host which went over the failure
            threshold should be blocked. Must be a number followed by units
            (e.g. '1 hr' or '3 days'). Unqualified number is treated as hours.
            default: $duration_default
server_port Port on which to listen to blacklisting reports
            default: $server_port_default
interval    length of time to wait between distributing blacklist updates
            (same syntax as duration)
            default: $interval_default
verbose     verbosity level. Can be indicated either numerically, or by
            using multiple '-v' options.
            default: $verbose_default
EOT
}

sub log_prefix() {
    # Prefix put in front of every log line: "<local time> <daemon name>[<pid>]: ".
    my $timestamp = scalar localtime;
    return "$timestamp $name\[$$\]: ";
}

sub log_out {
    # Return copies of the given messages with the log prefix inserted at the
    # start of every line they contain; the caller's strings are not modified.
    my $prefix = log_prefix();
    return map { (my $copy = $_) =~ s/^/$prefix/gms; $copy } @_;
}

sub log_strings {
    # Write the messages to STDERR, one prefixed line each, newline-terminated.
    my $text = join ("\n", log_out (@_));
    print STDERR $text, "\n";
}

sub mail_to_users {
    # Send $text with the given $subject to @users (default: @sysmail) by
    # speaking SMTP directly to $mail_server on port 25.
    #
    # Best-effort: if the mail server cannot be reached, a warning is issued
    # and the message is dropped, so a mail outage cannot take the daemon down.
    # Commands are written without waiting for the server's replies.
    # NOTE(review): servers that enforce strict (non-pipelined) SMTP may reject
    # this; confirm against the site's mail server configuration.
    # Returns the result of closing the connection, or nothing on failure.
    my ($text, $subject, @users) = @_;
    @users = @sysmail unless @users;

    # A server that drops the connection mid-write must not kill us via SIGPIPE.
    local $SIG{PIPE} = 'IGNORE';

    my $socket = IO::Socket::INET->new(PeerAddr => "$mail_server:25",
                                       Proto    => 'tcp',
                                       Timeout  => 30);
    unless ($socket) {
        warn "Cannot connect to $mail_server:25 to send '$subject': $@\n";
        return;
    }

    # SMTP requires CRLF line endings, and a body line starting with "." must
    # be escaped by doubling the dot (RFC 5321, 4.5.2); otherwise a line that
    # is just "." would terminate the message early.
    my $message = "To: " . join (",", @users) . "\n"
                . "Subject: $subject\n\n"
                . (defined $text ? $text : "");
    $message =~ s/\r?\n/\r\n/g;
    $message =~ s/^\./../mg;
    $message .= "\r\n" unless $message =~ /\r\n\z/;

    print $socket "HELO $shost.$domain\r\n";
    print $socket "MAIL FROM:<root\@$shost.$domain>\r\n";
    print $socket "RCPT TO:<$_>\r\n" for @users;
    print $socket "DATA\r\n";
    print $socket $message;
    print $socket ".\r\nQUIT\r\n";
    return close $socket;
}

sub die_with_mail($;@) {
    # Mail $text to the given recipients (default: @sysmail) under the subject
    # "<name> died on <short host>", then terminate: die with $text when STDIN
    # is a terminal (interactive run), otherwise exit with status 1.
    my ($text, @recipients) = @_;
    @recipients = @sysmail if !@recipients;
    mail_to_users ($text, "$name died on $shost", @recipients);
    exit 1 unless -t STDIN;
    die $text;
}

sub max ($;@) {
    # Largest argument under numeric comparison; on ties the earliest wins.
    my ($largest, @others) = @_;
    foreach my $candidate (@others) {
        $largest = $candidate if $candidate > $largest;
    }
    return $largest;
}
sub min ($;@) {
    # Smallest argument under numeric comparison; on ties the earliest wins.
    my ($smallest, @others) = @_;
    foreach my $candidate (@others) {
        $smallest = $candidate if $candidate < $smallest;
    }
    return $smallest;
}

use Scalar::Util ();

sub ObjectContents ($;$$$);

# One-line description of a member value for ObjectContents: array and hash
# references (blessed or not) get their member count appended.
sub _object_member_summary {
    my $value = shift;
    my $type  = Scalar::Util::reftype ($value);
    return "$value" unless defined $type;
    return "$value   (" . scalar (@$value) . " members)"      if $type eq 'ARRAY';
    return "$value   (" . scalar (keys %$value) . " members)" if $type eq 'HASH';
    return "$value";
}

sub ObjectContents ($;$$$) {
    # Return a human-readable dump of $object, recursing into members that are
    # references until $depth levels have been printed. Useful for examining
    # object content in logs.
    #
    # $object - hash/array/scalar reference (blessed or not) or plain value
    # $depth  - number of levels to print (default 1; 0 or less prints nothing)
    # $offset - indentation prefix; used internally to offset children from
    #           their parent for easier visual processing (default two spaces)
    # $force_long_array_output - passed through to recursive calls only
    #
    # The container type is taken from Scalar::Util::reftype rather than from
    # the stringified reference, so blessed objects of any class and plain
    # strings that happen to contain "HASH" or "SCALAR" cannot trigger a
    # bogus dereference.
    my $object = shift;
    my $depth  = shift; $depth  = 1    unless defined $depth;
    my $offset = shift; $offset = "  " unless defined $offset;
    my $force_long_array_output = shift || 0;
    my $output = "";

    return $output unless $depth > 0;

    my $type = Scalar::Util::reftype ($object);
    $type = "" unless defined $type;

    if ($type eq 'HASH') {
        my $maxlen = 0;
        for my $key (keys %$object) {
            $maxlen = length ($key) if length ($key) > $maxlen;
        }
        for my $key (sort keys %$object) {
            my $head = sprintf ("%s%-${maxlen}s =", $offset, $key);
            $output .= "$head> ";
            my $value = $object->{$key};
            if (!defined $value) { $output .= "undefined\n"; next; }
            my $string = _object_member_summary ($value);
            my $pad = " " x length ($head);
            $string =~ s/\n/\n$pad> /g;
            $output .= "$string\n";
            $output .= ObjectContents ($value, $depth - 1, "$offset   ", $force_long_array_output)
                if ref $value;
        }
    } elsif ($type eq 'ARRAY') {
        for my $index (0 .. $#$object) {
            my $head = sprintf ("%s%-4s =", $offset, "[$index]");
            $output .= "$head> ";
            my $value = $object->[$index];
            if (!defined $value) { $output .= "undefined\n"; next; }
            my $string = _object_member_summary ($value);
            my $pad = " " x length ($head);
            # Array members drop one newline that ends a line (normally the
            # trailing one) before continuation lines are padded.
            $string =~ s/\n$//m;
            $string =~ s/\n/\n$pad> /g;
            $output .= "$string\n";
            $output .= ObjectContents ($value, $depth - 1, "$offset   ", $force_long_array_output)
                if ref $value;
        }
    } elsif ($type eq 'SCALAR') {
        my $string = $$object;
        $string =~ s/\n\s*$//gsm;
        my $pad = ($string =~ /\n/m) ? "> " : "";
        $string =~ s/\n/\n$offset$pad/g;
        $output .= sprintf ("%s%s%-12s\n", $offset, $pad, $string);
    } else {
        $output .= sprintf ("%s+%-12s\n", $offset, $object);
    }
    return $output;
}

sub negotiate_pid ($) {
    # Make sure at most one instance of the daemon runs, then record our PID
    # in $pidfile.
    #
    # $restart - when true, a live predecessor recorded in $pidfile is killed
    #            with SIGKILL; when false, we exit(1) and leave it alone.
    #
    # Dies if the PID file is unreadable or corrupt, if the predecessor
    # survives the kill, or if our own PID cannot be written.
    my $restart = shift;

    if (-s $pidfile) {
        # PID file exists and is not empty
        open (my $in, '<', $pidfile) or die "Cannot read PID file $pidfile: $!\n";
        my $pid = <$in>;
        close $in;
        $pid = "" unless defined $pid;
        chomp $pid;
        # kill() with PID 0 would signal our whole process group, so only a
        # real, positive PID is acceptable.
        die "Corrupt PID file! (read '$pid' from it)\n" unless $pid =~ /^\d+$/ && $pid > 0;

        # A file holding our own PID is stale (the PID was reused), not a rival.
        if ($pid != $$ && kill (0, $pid)) {
            # The process is alive
            exit 1 unless $restart;   # another instance is running, leave it alone
            kill (9, $pid);
            # Give the kernel a few seconds to tear the predecessor down.
            for (my $tries = 0; $tries < 5 && kill (0, $pid); ++$tries) { sleep 1; }
            die "Cannot kill predecessor, PID $pid\n" if kill (0, $pid);
        }
        # The recorded process is gone (or was just killed): discard the file.
        unlink $pidfile;
    } elsif (-e $pidfile) {
        # PID file exists but is empty, ignore it.
        unlink $pidfile;
    }

    die "PID file $pidfile unexpectedly exists!\n" if -e $pidfile;
    open (my $out, '>', $pidfile) or die "Couldn't write my PID ($$) to $pidfile: $!\n";
    print $out "$$\n";
    close $out or die "Couldn't write my PID ($$) to $pidfile: $!\n";
}

sub normalize_duration ($) {
    # Convert a human-readable duration such as "1 min", "3 days" or
    # "2 months" into seconds. Whitespace and case are ignored, and a bare
    # number means hours.
    #
    # A unit is recognized by its leading letter(s): s(econds), m(inutes),
    # mo(nths), h(ours), d(ays), w(eeks), y(ears). Months count as 30 days
    # and years as 365 days.
    #
    # Returns the number of seconds, or undef when the spec is missing,
    # unparseable, zero, or uses an unknown unit.
    my $spec = shift;
    return unless defined $spec;

    (my $compact = lc $spec) =~ s/\s//g;
    # Appending "h" makes an unqualified number default to hours. It is
    # harmless when a unit is present, because only the unit's leading
    # letters are inspected.
    my ($num, $unit) = ($compact . "h") =~ /^(\d+)([a-z]+)/;
    return unless $num && $unit;

    my %seconds_per = (
        s => 1,
        m => 60,
        h => 60*60,
        d => 60*60*24,
        w => 60*60*24*7,
        y => 60*60*24*365,
    );
    # "mo..." must be recognized before the single-letter lookup; otherwise
    # months would be read as minutes.
    my $multiplier = ($unit =~ /^mo/) ? 60*60*24*30 : $seconds_per{substr ($unit, 0, 1)};
    return unless $multiplier;
    return $num * $multiplier;
}

sub read_blacklist () {
    # Load the persisted blacklist from /etc/$name-blacklist.
    # Each line is "<host> <client> <expiry epoch>"; '#' starts a comment.
    # Returns { host => { client => expiry } }. Malformed and already-expired
    # entries are skipped, and an empty hashref is returned if the file
    # cannot be opened.
    open (my $fh, '<', "/etc/$name-blacklist") or return {};
    my $blacklist = {};
    my $now = time();
    while (my $line = <$fh>) {
        chomp $line;
        $line =~ s/\#.*//;
        # split ' ' skips leading whitespace; split /\s+/ would produce an
        # empty first field and drop an otherwise valid, indented line.
        my ($host, $client, $time) = split (' ', $line);
        next unless ($host && $client && $time && ($time =~ /^\d+$/));
        $blacklist->{$host}{$client} = $time if $time > $now;
    }
    close $fh;
    return $blacklist;
}

sub write_blacklist ($) {
    # Persist $blacklist ({ host => { client => expiry epoch } }) to
    # /etc/$name-blacklist, one "<host> <client> <expiry>   # <local time>"
    # line per report, sorted by host and client.
    #
    # The data is written to a temporary file and renamed into place, so a
    # crash mid-write cannot leave a truncated blacklist behind.
    # Returns $blacklist on success, or nothing if the file could not be written.
    my $blacklist = shift;
    my $file = "/etc/$name-blacklist";
    my $tmp  = "$file.tmp.$$";

    open (my $fh, '>', $tmp) or return;
    for my $host (sort keys %$blacklist) {
        for my $client (sort keys %{$blacklist->{$host}}) {
            my $expires = $blacklist->{$host}{$client};
            printf {$fh} "%-15s %-15s %s   # %s\n",
                $host, $client, $expires, scalar (localtime ($expires));
        }
    }
    unless (close ($fh) and rename ($tmp, $file)) {
        unlink $tmp;
        return;
    }
    return $blacklist;
}

sub read_clientlist () {
    # Load registered clients from /etc/$name-clientlist.
    # Each line is "<client>:<port>" (whitespace also accepted as separator);
    # '#' starts a comment.
    # Returns { client => port }. Lines without a numeric port are skipped,
    # and an empty hashref is returned if the file cannot be opened.
    open (my $fh, '<', "/etc/$name-clientlist") or return {};
    my $list = {};
    while (my $line = <$fh>) {
        chomp $line;
        $line =~ s/\#.*//;
        # Leading whitespace would otherwise become an empty first field.
        $line =~ s/^\s+//;
        my ($client, $port) = split (/[\s:]+/, $line);
        next unless ($client && $port && ($port =~ /^\d+$/));
        $list->{$client} = $port;
    }
    close $fh;
    return $list;
}

sub write_clientlist ($) {
    # Persist $clients ({ client IP => port }) to /etc/$name-clientlist as
    # "<client>:<port>   # <reverse DNS name>" lines, sorted by client.
    #
    # The file is written to a temporary name and renamed into place, so a
    # crash mid-write cannot truncate the client list.
    # Returns $clients on success, or nothing if the file could not be written.
    my $clients = shift;
    my $file = "/etc/$name-clientlist";
    my $tmp  = "$file.tmp.$$";

    open (my $fh, '>', $tmp) or return;
    for my $client (sort keys %$clients) {
        my $ip = inet_aton ($client);
        my $hostname = $ip ? gethostbyaddr ($ip, AF_INET) : undef;
        $hostname = 'unresolved' unless defined $hostname;
        printf {$fh} "%-30s # %s\n", "$client:$clients->{$client}", $hostname;
    }
    unless (close ($fh) and rename ($tmp, $file)) {
        unlink $tmp;
        return;
    }
    return $clients;
}

sub read_line_nonblock ($;$) {
    # Read up to 128 bytes from $handle, giving up after roughly $delay
    # seconds (default 1) if nothing arrives.
    #
    # Returns the bytes read, or "" on timeout or end-of-file. Returns undef
    # if the handle's flags cannot be read or changed (the handle is then
    # closed), or if sysread fails with an error other than EAGAIN/EINTR.
    # The handle's original file status flags are restored before returning.
    my $handle = shift;
    my $delay = shift || 1;

    # F_GETFL reports the flags as fcntl's return value (the third argument
    # is ignored); "0 but true" comes back when no flags are set.
    my $flags = fcntl ($handle, F_GETFL, 0);
    unless (defined $flags) { close $handle; warn "Couldn't get flags\n"; return undef; }
    unless (fcntl ($handle, F_SETFL, $flags | O_NONBLOCK)) {
        close $handle; warn "Couldn't set flags\n"; return undef;
    }

    my $deadline = time() + $delay;
    my $buffer;
    for (;;) {
        my $rv = sysread ($handle, $buffer, 128);
        last if defined $rv;                       # data, or 0 at EOF ($buffer is "")
        last unless $! == EAGAIN || $! == EINTR;   # real error: $buffer stays undef
        if (time() > $deadline) { $buffer = ""; last; }
        # Sleep until the handle is readable (at most 1s) instead of spinning.
        my $readable = '';
        vec ($readable, fileno ($handle), 1) = 1;
        select ($readable, undef, undef, 1);
    }
    unless (fcntl ($handle, F_SETFL, $flags)) { warn "Couldn't reset flags\n"; }
    return $buffer;
}

sub cleanup_children ($) {
    # SIGKILL any still-running process listed in $children ({ pid => 1 })
    # whose parent is this process, then forget every entry.
    # Returns the (now empty) $children hashref.
    #
    # The parent check guards against PID reuse: a listed PID that now belongs
    # to an unrelated process must not be killed.
    my $children = shift;
    for my $child (keys %$children) {
        if (kill (0, $child) and open (my $stat_fh, '<', "/proc/$child/stat")) {
            # the process is alive, see if it's ours
            my $stat = <$stat_fh>;
            close $stat_fh;
            # /proc/<pid>/stat is "pid (comm) state ppid ...". comm may contain
            # spaces or ')', so anchor on the LAST ')' before reading ppid.
            my $ppid;
            ($ppid) = $stat =~ /.*\)\s+\S+\s+(\d+)/s if defined $stat;
            if (defined $ppid && $ppid == $$) {
                # this child's PPID is me!
                warn "Killing recalcitrant $child...\n";
                kill (9, $child);
            }
        }
        delete $children->{$child};
    }
    return $children;
}

sub cleanup_clients ($) {
    # Drop clients that have failed to accept a blacklist push three or more
    # times, deleting their failure markers, then persist the client list.
    #
    # Failure markers are files named "<client>-<pid>-<epoch>" in $spooldir.
    # NOTE(review): markers are not removed after a successful push (in the
    # visible source), so failures count cumulatively, not consecutively.
    #
    # $clients - { client IP => port }; modified in place.
    # Returns whatever write_clientlist() returns.
    my $clients = shift;
    for my $client (keys %$clients) {
        # The "-" keeps e.g. 10.0.0.1 from matching 10.0.0.10's markers.
        my @failures = glob ("$spooldir/$client-*");
        if (@failures > 2) {
            unlink @failures;
            delete $clients->{$client};
        }
    }
    return write_clientlist ($clients);
}

sub push_list_to_clients ($$) {
    # This function is run by each forked child: connect to every client in
    # $queue and send it the effective blacklist, one
    # "<host> <expiry epoch>" line per host.
    #
    # $list  - { blacklisted host => expiry epoch }
    # $queue - { client IP => callback port }; entries with a false port are skipped.
    #
    # Each unreachable client gets an empty marker file in $spooldir named
    # "<client>-<pid>-<epoch>". A combined report goes to STDERR via log_strings().
    my ($list, $queue) = @_;

    # A client closing its end early would otherwise kill this child with
    # SIGPIPE before the rest of the queue is served.
    local $SIG{PIPE} = 'IGNORE';

    my @strs = split (/\n/, sprintf ("------------ Child $$ reporting:\n%s",
                                     ObjectContents ({blacklist => $list, queue => $queue}, 3)));

    for my $client (keys %$queue) {
        my $port = $queue->{$client} or next;
        push (@strs, "Contacting $client:$port...");
        my $socket = IO::Socket::INET->new(PeerAddr => $client,
                                           PeerPort => $port,
                                           Proto    => 'tcp',
                                           Type     => SOCK_STREAM,
                                           Timeout  => 1);
        if ($socket) {
            $strs[-1] .= " Success!";
            for my $host (keys %$list) {
                unless (print $socket "$host $list->{$host}\n") {
                    $strs[-1] .= " (send failed: $!)";
                    last;
                }
            }
            close $socket;
        } else {
            $strs[-1] .= " Failure!";
            my $marker = "$spooldir/$client-$$-" . time();
            if (open (my $fh, '>', $marker)) { close $fh; }
            else { push (@strs, "Cannot create failure marker $marker: $!"); }
        }
    }

    log_strings(@strs);
}

sub distribute_blacklist ($$;$) {
    # Derive the effective blacklist from the per-client reports and push it
    # to clients from forked children, at most once per $interval seconds.
    #
    # $list    - { host => { reporting client IP => expiry epoch } }. A host
    #            with at least $threshold reporters whose latest expiry has
    #            passed is deleted from it in place.
    # $clients - { client IP => callback port } to push the blacklist to.
    # $force   - when true, push immediately regardless of $interval (used
    #            for a newly registered client). A forced push does not reset
    #            the interval timer, kill in-flight children, or prune and
    #            rewrite the persistent client list, because $clients is then
    #            only a one-off subset of all clients.
    #
    # Mails a report to @sysmail when hosts absent from the last mailed report
    # appear. Returns 0 when the push was skipped, 1 otherwise.
    my ($list, $clients, $force) = @_;
    my $blist = {};
    my $now = time();

    # Extract the actual blacklist: hosts reported by at least $threshold
    # distinct clients, blocked until the latest expiry any of them reported.
    for my $host (keys %$list) {
        my $reports = $list->{$host};
        next unless keys (%$reports) >= $threshold;
        my $maxtime = 0;
        for my $expiry (values %$reports) {
            $maxtime = $expiry if $expiry > $maxtime;
        }
        if ($maxtime > $now) { $blist->{$host} = $maxtime; }
        else                 { delete $list->{$host}; }
    }

    my @new_hosts = grep { !$old_blist{$_} } keys %$blist;

    my $report = "Blacklist distribution in progress.\n\n";
    $report .= sprintf ("The following host has just been added to the blacklist:\n%s\n",
                        join ("\n", @new_hosts)."\n\nThe entire blacklist is:")
        if @new_hosts;
    for my $host (keys %$blist) {
        $report .= sprintf ("    %-20s -- expires %s\n", $host, scalar (localtime ($blist->{$host})));
    }
    $report .= "\n";

    # FIXME
    $report .= "the old list was:\n" if %old_blist;
    for my $host (keys %old_blist) {
        $report .= sprintf ("    %-20s -- expires %s\n", $host, scalar (localtime ($old_blist{$host})));
    }

    unless (%$clients && ($force || time() - $last_distrib > $interval)) {
        # The update came too soon (or there is nobody to send it to)
        log_strings ("--- Skipping distribution at this time...");
        return 0;
    }

    unless ($force) {
        # $interval seconds have passed since the last distribution, proceed
        $last_distrib = time();
        cleanup_children ($children);
        cleanup_clients ($clients);
    }
    log_strings ("--- Distribute the data:", ObjectContents ({clients => $clients, blacklist => $blist}, 5));

    # Split the clients into about sqrt(N) queues of equal size; each queue
    # is handled by one forked child.
    my @pending = sort keys %$clients;
    my $numqueues = int (sqrt (@pending));
    my $queuesize = int ((@pending - 1) / $numqueues + 1);
    log_strings ("$numqueues queues of up to $queuesize clients each");
    $report .= "$numqueues queues of up to $queuesize clients each\n";
    my @clists;
    while (@pending) {
        my %queue;
        for (1 .. $queuesize) {
            last unless @pending;
            my $client = pop (@pending);
            $queue{$client} = $clients->{$client};
        }
        push (@clists, \%queue);
    }

    $report .= "See /var/log/sshd_sentry_server.log for detailed info\n\n".
        ObjectContents ({raw_blacklist => $list}, 5);

    if (@new_hosts) {
        mail_to_users ($report, "Blacklisting report", @sysmail);
        %old_blist = %$blist;
    }

    for my $queue (@clists) {
        my $child = fork();
        if (!defined $child) {
            # Without this check a failed fork() would run the child branch
            # in the daemon itself and then exit it.
            log_strings ("Cannot fork to push to " . join (", ", sort keys %$queue) . ": $!");
        } elsif ($child) {
            # we are the parent
            $children->{$child} = 1;
        } else {
            # we are the child
            push_list_to_clients ($blist, $queue);
            exit 0;
        }
    }
    return 1;
}

#############################
#                           #
#   Execution begins here   #
#                           #
#############################
GetOptions ("help"            => \$help,
            "restart"         => \$restart,
            "threshold=i"     => \$threshold,
            "duration=s"      => \$duration,
            "server_port=i"   => \$server_port,
            "interval=s"      => \$interval,
            # ":+" takes an optional number and increments when it is omitted,
            # so both "-v 2" and "-v -v" work as the help text promises.
            "verbose:+"       => \$verbose
) or die help();

if ($help) { print help(); exit 0; }

die "The server can currently only run on Linux.\n" unless $^O =~ /linux/i;

-d $spooldir or mkdir ($spooldir) or die "Cannot create spool directory $spooldir: $!\n";

negotiate_pid($restart);

# Activate $<option>_default values for options not given on the command line.
$threshold   ||= $threshold_default;
$duration    ||= $duration_default;
$server_port ||= $server_port_default;
$interval    ||= $interval_default;
$verbose     ||= $verbose_default;

# Keep the user's spelling for error messages; normalization yields undef
# for an invalid spec.
my ($duration_spec, $interval_spec) = ($duration, $interval);
$interval = normalize_duration ($interval);
$duration = normalize_duration ($duration);

die "Bad duration spec ($duration_spec)\n" unless $duration;
die "Bad interval spec ($interval_spec)\n" unless $interval;

$blacklist = read_blacklist();
$clients = read_clientlist();

# Peers on these subnets may report to us, and are never blacklisted
# themselves. Dots are escaped and the octet boundary is explicit, so that
# e.g. 10.128.119.41 or 1.128.119.42 do not pass.
my $trusted_ip = qr/^128\.119\.(?:24[0-7]|4[12])\./;
# Hosts in our own domain are never blacklisted either.
my $local_host = qr/(?:^|\.)\Q$domain\E\z/i;

# Writing to a client that has already hung up must not kill the daemon.
$SIG{PIPE} = 'IGNORE';

log_strings ("Listening on port $server_port");
my $server = IO::Socket::INET->new (LocalPort => $server_port,
                                    Type => SOCK_STREAM,
                                    Reuse => 1,
                                    Listen => 10)
    or die "Cannot bind $name to port $server_port: $@\n";

for (;;) {
    my $client = $server->accept();
    unless ($client) {
        # Transient accept() failures must not end the daemon; back off
        # briefly and keep serving.
        log_strings ("accept() failed: $!");
        sleep 1;
        next;
    }
    my $peeraddr = $client->peeraddr();
    unless ($peeraddr) { close $client; next; }   # peer already vanished
    my $got_update = 0;
    my $peer = inet_ntoa ($peeraddr);
    log_strings ("Accepted from $peer");

    if ($peer =~ $trusted_ip) {
        # one of ours, talk to them

        # Each line consists of a port to upload blacklists on, followed by a new
        # blacklist entry. Note that it's safe to supply just a port number -- this
        # will effectively register the connecting host as a client for future
        # blacklist distributions.
        #    i.e.
        # <callback port>:<blacklisted host>
        my $line = read_line_nonblock ($client);
        $line = "" unless defined $line;
        chomp $line;
        my ($port, $host) = split (/[:\n\r]+/, $line);
        my $port_ok = defined ($port) && $port =~ /^\d+$/ && $port > 0 && $port < 65536;

        if ($port_ok) {
            unless ($clients->{$peer}) {
                # This client just registered, distribute the blacklist to them right away
                log_strings ("Distributing NOW...");
                distribute_blacklist ($blacklist, {$peer => $port}, 1);
            }
            $clients->{$peer} = $port;
            write_clientlist ($clients);  $clients = read_clientlist();
        }

        if ($port_ok && $host && $host !~ $trusted_ip && $host !~ $local_host) {
            $blacklist->{$host} ||= {};

            # Don't automatically distribute updates if we only got a repeat blacklist
            # entry from a host which had already reported it
            $got_update = 1; # unless $blacklist->{$host}->{$peer}; # FIXME

            $blacklist->{$host}->{$peer} = time() + $duration;
            print $client "Gotcha.\n";
        } else {
            # Line is not valid, likely due to a timeout issue.
            print $client "Bad line (idle timeout?)\n";
        }
    } else {
        # the client is from outside our domain
        print $client "You are not authorized, go away\n";
    }
    close $client;

    if ($got_update) {
        write_blacklist ($blacklist); $blacklist = read_blacklist();
        distribute_blacklist ($blacklist, $clients);
    }
}
