#!/usr/bin/python -Es
#
# Authors: Dan Walsh <dwalsh@redhat.com>
#
# Copyright (C) 2012 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
#

# Program name; used below as the gettext translation domain.
PROGNAME="sandboxcontainer"
from gi.repository import LibvirtGObject
from gi.repository import LibvirtSandbox
from gi.repository import GLib
#from gi.repository import Gtk
import gi
import os, sys, shutil, errno, stat
import rpm
from subprocess import Popen, PIPE, STDOUT
import gettext
import selinux

# Run the init checks of the libvirt GObject and sandbox bindings at import
# time, before any of their classes are used further down.
LibvirtGObject.init_object_check(None)
LibvirtSandbox.init_check(None)

# Bind the translation domain and install _() as a builtin.
# NOTE: the `unicode`/`codeset` keywords, `__builtin__` and `unicode` used
# here exist only on Python 2.
gettext.bindtextdomain(PROGNAME, "/usr/share/locale")
gettext.textdomain(PROGNAME)
try:
    gettext.install(PROGNAME,
                    localedir="/usr/share/locale",
                    unicode=False,
                    codeset = 'utf-8')
except IOError:
    # Fall back to an untranslated _() so that later _() calls still work.
    import __builtin__
    __builtin__.__dict__['_'] = unicode

class Container:
    """A libvirt-sandbox service container and the files that back it.

    A container consists of a sandbox config file (CONFIG_PATH), a content
    directory below ``path`` (optionally populated through an ext4 image
    file, IMAGE_PATH) and a systemd unit file in UNIT_DIR.  ``create``
    builds all of these, ``delete`` removes them again, and ``start`` /
    ``stop`` / ``reload`` drive the sandbox through a libvirt connection.
    """
    DEFAULT_DIRS  = [ "/etc", "/var" ]
    CONFIG_PATH   = "/etc/libvirt-sandbox/%s.sandbox"
    IMAGE_PATH    = "/var/lib/libvirt/images/%s.image"
    UNIT_DIR      = "/etc/systemd/system"
    SYSTEM_DIRS   = [ "/tmp", "/etc/pki", "/var/tmp", "/dev/shm", "/var/log", "/var/lib/dhclient", "/run", "/home", "/root", "/var/spool" ]
    SELINUX_FILE_TYPE  = "svirt_lxc_file_t"
    SELINUX_TYPE  = "svirt_lxc_net_t"
    SELINUX_LEVEL = "s0-s0:c0.c1023"

    def __init__(self, name=None, path = "/var/lib/libvirt/filesystems"):
        """Set defaults and, when ``name`` is given, load its saved config.

        name -- name of an existing container, or None for a container
                that is about to be created with ``create``.
        path -- directory below which container content is stored.

        Raises OSError if the config file of ``name`` cannot be loaded.
        """
        self.clone = False
        self.image = None
        self.name = name
        self.path = path
        self.service_files = []
        self.conn = None
        self.config = None
        self.mcs = "s0"
        self.dynamic = False
        self.file_type = self.SELINUX_FILE_TYPE
        self.security_type = self.SELINUX_TYPE
        self.security_level = self.SELINUX_LEVEL
        # Filled in by __extract_rpm(); initialised here so that the
        # attributes always exist.
        self.rpm_name = None
        self.all_dirs = []
        self.dirs = []
        self.size = 10 * MB
        if self.name:
            config = self.CONFIG_PATH % self.name
            try:
                self.config = LibvirtSandbox.Config()
                self.config.load_path(config)
            except gi._glib.GError as e:
                # Re-raise as OSError, which the command line front end
                # reports together with the usage text.
                raise OSError(config + ": " + str(e))

    @staticmethod
    def __is_under(path, parent):
        """Return True if ``path`` is ``parent`` or lies below it.

        The comparison is done on whole path components, so "/etc/foobar"
        is not considered to be below "/etc/foo".
        """
        return path == parent or path.startswith(parent.rstrip("/") + "/")

    @staticmethod
    def __makedirs(path):
        """Create ``path`` and its parents; an existing path is not an error."""
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    def __follow_units(self):
        """Return one commented-out ReloadPropagatedFrom= line per service file."""
        return "".join("#ReloadPropagatedFrom=%s\n" % unit
                       for unit in self.service_files)

    def __create_system_unit(self):
        """Write the systemd unit file that starts/reloads/stops the container.

        The file is named "<rpm_name>@<name>.service" when the executable
        belongs to a package and "<name>.service" otherwise; its path is
        kept in ``self.unitfile``.
        """
        if self.rpm_name:
            name = "%s@%s.service" % (self.rpm_name, self.name)
        else:
            # Keep the ".service" suffix so the file is a valid unit name
            # and matches the name delete() looks for.
            name = "%s.service" % self.name

        self.unitfile = "%s/%s" % (self.UNIT_DIR, name)

        unit = """
[Unit]
Description=Secure Container %s for %s Server
Requires=libvirtd.service
After=libvirtd.service
%s 
[Service]
Type=simple
ExecStart=/usr/sbin/virt-sandbox-service start %s
ExecReload=/usr/sbin/virt-sandbox-service reload %s
ExecStop=/usr/sbin/virt-sandbox-service stop %s

[Install]
WantedBy=multi-user.target
""" % ( self.name , self.executable, self.__follow_units(), self.name, self.name, self.name)
        with open(self.unitfile, "w") as fd:
            fd.write(unit)
        sys.stdout.write("Created unit file %s\n" % self.unitfile)

    def __unit_files(self):
        """Return the existing unit files that belong to this container.

        This is "<name>.service" plus every "<something>@<name>.service"
        whose content carries the ExecStart line written by
        __create_system_unit() for this container, so that template
        instances of unrelated services are never touched.
        """
        found = []
        plain = "%s/%s.service" % (self.UNIT_DIR, self.name)
        if os.path.exists(plain):
            found.append(plain)
        if not os.path.isdir(self.UNIT_DIR):
            return found

        suffix = "@%s.service" % self.name
        marker = "ExecStart=/usr/sbin/virt-sandbox-service start %s\n" % self.name
        for entry in sorted(os.listdir(self.UNIT_DIR)):
            if not entry.endswith(suffix):
                continue
            candidate = "%s/%s" % (self.UNIT_DIR, entry)
            try:
                with open(candidate) as fd:
                    owned = marker in fd.read()
            except IOError:
                # Unreadable or not a regular file: not one of ours.
                continue
            if owned:
                found.append(candidate)
        return found

    def __add_dir(self, newd):
        """Record directory ``newd``.

        ``all_dirs`` receives every directory; ``dirs`` only keeps the
        top-most ones, i.e. a directory is dropped from ``dirs`` as soon as
        one of its parents is recorded.
        """
        self.all_dirs.append(newd)
        for d in self.dirs:
            if self.__is_under(newd, d):
                # Already covered by a recorded parent (or a duplicate).
                return
        # newd replaces every directory it contains, not just the first.
        self.dirs = [d for d in self.dirs if not self.__is_under(d, newd)]
        self.dirs.append(newd)

    def set_clone(self, clone):
        """Select whether create() copies the package directories (True)
        or only creates them empty (False)."""
        self.clone = clone

    def get_dynamic(self):
        """Return the dynamic labeling flag."""
        return self.dynamic

    def set_dynamic(self, dynamic):
        """Set the dynamic labeling flag passed to the sandbox config."""
        self.dynamic = dynamic

    def set_security_type(self, security_type):
        """Set the SELinux type used in the container's process label."""
        self.security_type = security_type

    def set_security_level(self, security_level):
        """Set the SELinux level used in the container's security label."""
        self.security_level = security_level

    def set_image(self, size):
        """Back the container content with an image file of ``size`` bytes.

        ``size`` may be an integer or a decimal string.  When the container
        has no name yet, create() fills in the final image path.
        """
        self.image = self.IMAGE_PATH % self.name
        self.size = int(size)

    def set_path(self, path):
        """Set the directory below which container content is stored."""
        self.path = path

    def set_mcs(self, mcs):
        """Set the level used by get_process_label()/get_file_label()."""
        self.mcs = mcs

    def __extract_rpm(self):
        """Collect directories and service files from the owning package.

        Looks up the package that owns ``self.executable`` in the rpm
        database.  Existing directories of that package below DEFAULT_DIRS
        are recorded through __add_dir(), and files ending in ".service"
        are appended to ``service_files``.  When no package owns the
        executable, ``rpm_name`` stays None and nothing is recorded.
        """
        self.all_dirs = []
        self.dirs = []
        self.rpm_name = None

        ts = rpm.ts()
        mi = ts.dbMatch(rpm.RPMTAG_BASENAMES, self.executable)
        try:
            h = next(mi)
        except StopIteration:
            # Executable is not owned by any installed package.
            return
        self.rpm_name = h['name']
        for rec in h.fiFromHeader():
            f = rec[0]
            for b in self.DEFAULT_DIRS:
                if self.__is_under(f, b) and os.path.isdir(f):
                    self.__add_dir(f)
                    break

            if f.endswith(".service"):
                self.service_files.append(f)

    def __add_mount(self, source, dest):
        """Add a mount of ``source`` at ``dest`` to the sandbox config.

        A bind mount is used for image backed containers, a host mount
        otherwise.
        """
        mount = LibvirtSandbox.ConfigMount.new(dest)
        mount.set_root(source)
        if self.image:
            self.config.add_bind_mount(mount)
        else:
            self.config.add_host_mount(mount)

    def __gen_filesystems(self):
        """Add the image mount (if any) and all directory mounts to the config."""
        if self.image:
            self.config.add_host_mount(LibvirtSandbox.ConfigMount.new(self.image,self.dest))

        for d in self.SYSTEM_DIRS + self.dirs:
            source = "%s%s" % ( self.dest, d)
            self.__add_mount(source, d)

    def __get_init_path(self):
        """Return the path of the container's init file below ``path``."""
        return "%s/%s/init" % (self.path, self.name)

    def __gen_dirs(self):
        """Populate the container directory ``self.dest``.

        With ``clone`` set, the top-most package directories are copied
        from the host; otherwise every package directory is created empty.
        SYSTEM_DIRS are always created.
        """
        if self.clone:
            for d in self.dirs:
                shutil.copytree(d, "%s%s" % (self.dest, d), symlinks=True)
        else:
            # all_dirs may hold nested or repeated entries, so a directory
            # can already exist by the time it is reached.
            for d in self.all_dirs:
                self.__makedirs("%s%s" % (self.dest, d))

        for d in self.SYSTEM_DIRS:
            self.__makedirs("%s%s" % (self.dest, d))

    def __umount(self):
        """Unmount ``self.dest``; raises OSError if umount fails."""
        p = Popen(["/bin/umount", self.dest])
        p.communicate()
        if p.returncode:
            raise OSError("Failed to unmount image %s from %s " %  (self.image, self.dest))

    def __create_image(self):
        """Create the image file, format it as ext4 and mount it on ``self.dest``.

        Raises OSError if mkfs or mount fails.
        """
        with open(self.image, "w") as fd:
            fd.truncate(self.size)
        p = Popen(["/sbin/mkfs","-F", "-t", "ext4", self.image],stdout=PIPE, stderr=PIPE)
        p.communicate()
        if p.returncode:
            raise OSError("Failed to build image %s" % self.image )

        p = Popen(["/bin/mount", self.image, self.dest])
        p.communicate()
        if p.returncode:
            raise OSError("Failed to mount image %s on %s " %  (self.image, self.dest))

    def __save(self):
        """Write the sandbox config to CONFIG_PATH, replacing an old file."""
        config = self.CONFIG_PATH % self.name
        if os.path.exists(config):
            os.remove(config)
        self.config.save_path(config)
        sys.stdout.write("Created sandbox config %s\n" % config)

    def delete(self):
        """Stop the container and remove its content, image, config and units."""
        # Stop service if it is running.  Best effort: the container may
        # simply not be running.
        try:
            self.stop()
        except Exception:
            pass

        dest = "%s/%s" % (self.path, self.name)
        if os.path.exists(dest):
            shutil.rmtree(dest)
        image = self.IMAGE_PATH % self.name
        if os.path.exists(image):
            os.remove(image)
        config = self.CONFIG_PATH % self.name
        if os.path.exists(config):
            os.remove(config)
        for unitfile in self.__unit_files():
            os.remove(unitfile)

    def create(self, name, executable):
        """Create container ``name`` running ``executable``.

        Builds the sandbox config, creates and populates the content
        directory (through an image file when set_image() was called),
        then saves the config and writes the systemd unit file.

        Raises OSError if the content directory already exists or if
        building/mounting/unmounting the image fails.
        """
        self.dest = "%s/%s" % (self.path, name)
        if os.path.exists(self.dest):
            raise OSError("%s already exists" % self.dest)

        self.config = LibvirtSandbox.Config.new(name)
        self.config.set_security_label("system_u:system_r:%s:%s" % (self.security_type, self.security_level))
        self.config.set_security_dynamic(self.dynamic)
        net = LibvirtSandbox.ConfigNetwork().new()
        net.set_dhcp(True)
        self.config.add_network(net)
        self.name = name
        if self.image:
            # set_image() may have run before the name was known; derive
            # the image path from the final name.
            self.image = self.IMAGE_PATH % name
        self.executable = executable
        self.config.set_command([ self.executable ])
        self.config.set_tty(True)
        self.config.set_shell(True)

        self.__extract_rpm()
        self.__gen_filesystems()
        os.mkdir(self.dest)
        sys.stdout.write("Created container dir %s\n" % self.dest)
        if self.image:
            self.__create_image()
            try:
                self.__gen_dirs()
            finally:
                # Never leave the image mounted, even if populating fails.
                self.__umount()
        else:
            self.__gen_dirs()
            selinux.chcon(self.dest, "system_u:object_r:%s:%s" % (self.file_type, self.security_level), True)

#        resolv_conf = self.dest + "/etc/resolv.conf"
#        mount = LibvirtSandbox.ConfigMount.new("/etc/resolv.conf")
#        mount.set_root(resolv_conf)
#        self.config.add_host_mount(mount)
#        fd = open(resolv_conf,"w")
#        fd.close()

        self.__save()
        self.__create_system_unit()

    def reload(self):
        """Reload the container by stopping and starting it again."""
        # Crude way of doing this.
        self.stop()
        self.start()

    def __connect(self):
        """Open the "lxc:///" libvirt connection on first use."""
        if not self.conn:
            self.conn=LibvirtGObject.Connection.new("lxc:///")
            self.conn.open(None)

    def start(self):
        """Start the container and block until its console is closed."""
        def closed(obj, error):
            self.loop.quit()

        self.__connect()
        context = LibvirtSandbox.Context.new(self.conn, self.config)
        context.start()
        console = context.get_console()
        console.connect("closed", closed)
        console.attach_stdio()
        self.loop = GLib.MainLoop()
        self.loop.run()
        # Best-effort teardown once the main loop has quit.
        try:
            console.detach()
        except Exception:
            pass
        try:
            context.stop()
        except Exception:
            pass

    def stop(self):
        """Attach to the container's sandbox context and stop it."""
        self.__connect()
        context = LibvirtSandbox.Context.new(self.conn, self.config)
        context.attach()
        context.stop()

    def get_process_label(self):
        """Return the SELinux process label built from security type and mcs."""
        return "system_u:system_r:%s:%s" % (self.security_type, self.mcs)

    def get_file_label(self):
        """Return the SELinux file label built from file type and mcs."""
        return "system_u:object_r:%s:%s" % (self.file_type, self.mcs)

# Number of bytes in one (decimal) megabyte; used for image size defaults.
MB = 10 ** 6

def delete(args):
    """Command handler: delete the container named by ``args.name``."""
    Container(args.name).delete()

def create(args):
    """Command handler: create a container from the parsed ``create`` options.

    Reads ``clone``, ``dynamic``, ``type``, ``level``, ``path``, ``image``,
    ``size``, ``name`` and ``executable`` from ``args``.
    """
    container = Container()
    container.set_clone(args.clone)
    container.set_dynamic(args.dynamic)
    container.set_security_type(args.type)
    container.set_security_level(args.level)

    if args.level:
        container.set_mcs(args.level)

    # Honour --path; without it the container keeps its default location.
    path = getattr(args, "path", None)
    if path:
        container.set_path(path)

    if args.image:
        container.set_image(args.size)

    container.create(args.name, args.executable)

def usage(parser, msg):
    """Print the parser's help, then ``msg`` on stderr, and exit with status 1."""
    parser.print_help()

    err = sys.stderr
    text = "\n%s\n" % msg
    err.write(text)
    err.flush()
    raise SystemExit(1)

def start(args):
    """Command handler: start the container named by ``args.name``."""
    Container(args.name).start()

def reload(args):
    """Command handler: reload the container named by ``args.name``."""
    Container(args.name).reload()

def stop(args):
    """Command handler: stop the container named by ``args.name``."""
    Container(args.name).stop()

# Command line entry point: build the argument parser, dispatch to the
# handler of the selected sub-command and map failures to an exit status.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Secure Container Tool')
    subparsers = parser.add_subparsers(help="commands")
    parser_create = subparsers.add_parser("create", help="create a security container")

    parser_create.add_argument("-e", "--executable", required=True, dest="executable",  default=None,
                       help="Executable to run in the container")
    parser_create.add_argument("-p", "--path", dest="path",  default=None,
                       help="path name to store content")
    parser_create.add_argument("-t", "--type", dest="type", default=Container.SELINUX_TYPE,
                       help="SELinux type with which to run the container")
    group = parser_create.add_mutually_exclusive_group()
    group.add_argument("-l", "--level", dest="level", default=Container.SELINUX_LEVEL,
                       help="MCS Level with which to run the container")
    group.add_argument("-d", "--dynamic", dest="dynamic", default=False, action="store_true",
                       help="Use dynamic MCS labeling for the container")
    parser_create.add_argument("-i", "--image", dest="image", default=False, action="store_true",
                       help="use an image file for the container content")
    # type=int: the size is handed to file.truncate(), which needs a number.
    parser_create.add_argument("-s", "--size", dest="size", type=int, default=(10 * MB),
                       help="size of the image")
    parser_create.add_argument("-c", "--clone", default=False, action="store_true",
                       help="copy /etc and /var directories that will be mounted in the container")

    parser_create.set_defaults(func=create)

    parser_start = subparsers.add_parser("start", help="Start a security container")
    parser_start.set_defaults(func=start)

    # The generated unit file's ExecReload= invokes this sub-command.
    parser_reload = subparsers.add_parser("reload", help="Reload a security container")
    parser_reload.set_defaults(func=reload)

    parser_delete = subparsers.add_parser("delete", help="Delete a security container")
    parser_delete.set_defaults(func=delete)

    parser_stop = subparsers.add_parser("stop", help="Stop a security container")
    parser_stop.set_defaults(func=stop)

    # Declared on the top-level parser, after the sub-commands.
    parser.add_argument("name", help="name of the container")

    args = parser.parse_args()
    try:
        args.func(args)
        sys.exit(0)
    except gi._glib.GError as e:
        usage(parser,e)
    except OSError as e:
        usage(parser,e)
    except KeyError as e:
        usage(parser,e)
    except KeyboardInterrupt as e:
        sys.exit(0)
    except ValueError as e:
        # An argument may be a list of message lines or a single message;
        # print a plain message whole rather than character by character.
        for line in e.args:
            if isinstance(line, (list, tuple)):
                for l in line:
                    print(l)
            else:
                print(line)
        sys.exit(1)