Protocol overhaul

This commit is contained in:
Kenny Levinsen
2019-09-12 18:39:33 +02:00
parent 7cb451f075
commit dd176ea198
10 changed files with 337 additions and 164 deletions

36
Cargo.lock generated
View File

@@ -52,10 +52,19 @@ dependencies = [
]
[[package]]
name = "greetctl"
name = "greet_proto"
version = "0.1.0"
dependencies = [
"byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.100 (registry+https://github.com/rust-lang/crates.io-index)",
"serde_json 1.0.40 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "greetctl"
version = "0.1.0"
dependencies = [
"greet_proto 0.1.0",
"rpassword 4.0.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
@@ -63,8 +72,8 @@ dependencies = [
name = "greetd"
version = "0.1.0"
dependencies = [
"byteorder 1.3.2 (registry+https://github.com/rust-lang/crates.io-index)",
"clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)",
"greet_proto 0.1.0",
"nix 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)",
"pam 0.7.0 (git+https://github.com/regiontog/pam.git)",
"serde 1.0.100 (registry+https://github.com/rust-lang/crates.io-index)",
@@ -72,6 +81,11 @@ dependencies = [
"users 0.9.1 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "itoa"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "libc"
version = "0.2.62"
@@ -132,6 +146,11 @@ dependencies = [
"winapi 0.3.8 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "ryu"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
[[package]]
name = "serde"
version = "1.0.100"
@@ -150,6 +169,16 @@ dependencies = [
"syn 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "serde_json"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
dependencies = [
"itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)",
"ryu 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)",
"serde 1.0.100 (registry+https://github.com/rust-lang/crates.io-index)",
]
[[package]]
name = "strsim"
version = "0.8.0"
@@ -244,6 +273,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum cc 1.0.45 (registry+https://github.com/rust-lang/crates.io-index)" = "4fc9a35e1f4290eb9e5fc54ba6cf40671ed2a2514c3eeb2b2a908dda2ea5a1be"
"checksum cfg-if 0.1.9 (registry+https://github.com/rust-lang/crates.io-index)" = "b486ce3ccf7ffd79fdeb678eac06a9e6c09fc88d33836340becb8fffe87c5e33"
"checksum clap 2.33.0 (registry+https://github.com/rust-lang/crates.io-index)" = "5067f5bb2d80ef5d68b4c87db81601f0b75bca627bc2ef76b141d7b846a3c6d9"
"checksum itoa 0.4.4 (registry+https://github.com/rust-lang/crates.io-index)" = "501266b7edd0174f8530248f87f99c88fbe60ca4ef3dd486835b8d8d53136f7f"
"checksum libc 0.2.62 (registry+https://github.com/rust-lang/crates.io-index)" = "34fcd2c08d2f832f376f4173a231990fa5aef4e99fb569867318a227ef4c06ba"
"checksum nix 0.15.0 (registry+https://github.com/rust-lang/crates.io-index)" = "3b2e0b4f3320ed72aaedb9a5ac838690a8047c7b275da22711fddff4f8a14229"
"checksum pam 0.7.0 (git+https://github.com/regiontog/pam.git)" = "<none>"
@@ -251,8 +281,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
"checksum proc-macro2 1.0.3 (registry+https://github.com/rust-lang/crates.io-index)" = "e98a83a9f9b331f54b924e68a66acb1bb35cb01fb0a23645139967abefb697e8"
"checksum quote 1.0.2 (registry+https://github.com/rust-lang/crates.io-index)" = "053a8c8bcc71fcce321828dc897a98ab9760bef03a4fc36693c231e5b3216cfe"
"checksum rpassword 4.0.1 (registry+https://github.com/rust-lang/crates.io-index)" = "f072d931f11a96546efd97642e1e75e807345aced86b947f9239102f262d0fcd"
"checksum ryu 1.0.0 (registry+https://github.com/rust-lang/crates.io-index)" = "c92464b447c0ee8c4fb3824ecc8383b81717b9f1e74ba2e72540aef7b9f82997"
"checksum serde 1.0.100 (registry+https://github.com/rust-lang/crates.io-index)" = "f4473e8506b213730ff2061073b48fa51dcc66349219e2e7c5608f0296a1d95a"
"checksum serde_derive 1.0.100 (registry+https://github.com/rust-lang/crates.io-index)" = "11e410fde43e157d789fc290d26bc940778ad0fdd47836426fbac36573710dbb"
"checksum serde_json 1.0.40 (registry+https://github.com/rust-lang/crates.io-index)" = "051c49229f282f7c6f3813f8286cc1e3323e8051823fce42c7ea80fe13521704"
"checksum strsim 0.8.0 (registry+https://github.com/rust-lang/crates.io-index)" = "8ea5119cdb4c55b55d432abb513a0429384878c15dde60cc77b1c99de1a95a6a"
"checksum syn 1.0.5 (registry+https://github.com/rust-lang/crates.io-index)" = "66850e97125af79138385e9b88339cbcd037e3f28ceab8c5ad98e64f0f1f80bf"
"checksum textwrap 0.11.0 (registry+https://github.com/rust-lang/crates.io-index)" = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"

View File

@@ -1,2 +1,2 @@
[workspace]
members = ["greetd", "greetctl"]
members = ["greet_proto", "greetd", "greetctl"]

12
greet_proto/Cargo.toml Normal file
View File

@@ -0,0 +1,12 @@
[package]
name = "greet_proto"
version = "0.1.0"
authors = ["Kenny Levinsen <kl@kl.wtf>"]
edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
byteorder = "1.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

89
greet_proto/src/lib.rs Normal file
View File

@@ -0,0 +1,89 @@
use std::error::Error;
use std::io;
use std::collections::HashMap;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use serde::{Deserialize, Serialize};
pub struct Header {
pub version: u32,
pub len: u32,
}
impl Header {
pub const fn len() -> usize {
4 /* magic */ + 4 /* version */ + 4 /* payload length */
}
pub fn new(len: u32) -> Header {
Header{
version: 1,
len: len,
}
}
pub fn from_slice(bytes: &[u8]) -> Result<Header, Box<dyn Error>> {
let mut cursor = std::io::Cursor::new(bytes);
let proto_magic = cursor.read_u32::<LittleEndian>()?;
if proto_magic != 0xAFBFCFDF {
return Err(io::Error::new(io::ErrorKind::Other, "invalid message magic").into());
}
let proto_version = cursor.read_u32::<LittleEndian>()?;
let msg_len = cursor.read_u32::<LittleEndian>()?;
Ok(Header{
version: proto_version,
len: msg_len,
})
}
pub fn to_bytes(&self) -> Result<Vec<u8>, Box<dyn Error>> {
let mut buf = Vec::new();
buf.write_u32::<LittleEndian>(0xAFBFCFDF)?;
buf.write_u32::<LittleEndian>(self.version)?;
buf.write_u32::<LittleEndian>(self.len)?;
Ok(buf)
}
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum Request {
Login{
username: String,
password: String,
command: Vec<String>,
env: HashMap<String, String>
}
}
impl Request {
pub fn from_slice(bytes: &[u8]) -> Result<Request, Box<dyn Error>> {
serde_json::from_slice(bytes).map_err(|x| x.into())
}
pub fn to_bytes(&self) -> Result<Vec<u8>, Box<dyn Error>> {
serde_json::to_vec(self).map_err(|x| x.into())
}
}
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum Response {
LoginSuccess,
LoginFailure,
}
impl Response {
pub fn from_slice(bytes: &[u8]) -> Result<Response, Box<dyn Error>> {
serde_json::from_slice(bytes).map_err(|x| x.into())
}
pub fn to_bytes(&self) -> Result<Vec<u8>, Box<dyn Error>> {
serde_json::to_vec(self).map_err(|x| x.into())
}
}

View File

@@ -7,5 +7,5 @@ edition = "2018"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
byteorder = "1.3"
greet_proto = { path = "../greet_proto" }
rpassword = "4.0"

View File

@@ -1,8 +1,10 @@
use std::env;
use std::io::{self, BufRead, Write};
use std::io::{self, Read, BufRead, Write};
use std::os::unix::net::UnixStream;
use std::collections::HashMap;
use greet_proto::{Request, Response, Header};
use byteorder::{LittleEndian, WriteBytesExt};
use rpassword::prompt_password_stderr;
fn prompt_stderr(prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
@@ -12,32 +14,49 @@ fn prompt_stderr(prompt: &str) -> Result<String, Box<dyn std::error::Error>> {
Ok(stdin_iter.next().unwrap()?)
}
fn login(
username: String,
password: String,
cmd: String,
) -> Result<(), Box<dyn std::error::Error>> {
let msg_len = username.len() + password.len() + cmd.len() + 12;
let mut buf = Vec::with_capacity(msg_len + 16);
buf.write_u32::<LittleEndian>(0xAFBFCFDF)?; // Proto Magic
buf.write_u32::<LittleEndian>(1)?; // Proto version
buf.write_u32::<LittleEndian>(1)?; // Message type
buf.write_u32::<LittleEndian>(msg_len as u32)?; // Payload length
buf.write_u32::<LittleEndian>(username.len() as u32)?;
buf.extend(username.into_bytes());
buf.write_u32::<LittleEndian>(password.len() as u32)?;
buf.extend(password.into_bytes());
buf.write_u32::<LittleEndian>(cmd.len() as u32)?;
buf.extend(cmd.into_bytes());
let mut stream = UnixStream::connect(env::var("GREETD_SOCK")?)?;
stream.write_all(&buf)?;
Ok(())
}
fn main() {
fn login() -> Result<(), Box<dyn std::error::Error>> {
let username = prompt_stderr("Username: ").unwrap();
let password = prompt_password_stderr("Password: ").unwrap();
let command = prompt_stderr("Command: ").unwrap();
login(username, password, command).unwrap();
let request = Request::Login{
username,
password,
command: vec![command],
env: HashMap::new(),
};
let mut stream = UnixStream::connect(env::var("GREETD_SOCK")?)?;
// Write request
let req = request.to_bytes()?;
let header = Header::new(req.len() as u32);
stream.write_all(&header.to_bytes()?)?;
stream.write_all(&req)?;
// Read response
let mut header_buf = vec![0; Header::len()];
stream.read_exact(&mut header_buf)?;
let header = Header::from_slice(&header_buf)?;
let mut resp_buf = vec![0; header.len as usize];
stream.read_exact(&mut resp_buf)?;
let resp = Response::from_slice(&resp_buf)?;
match resp {
Response::LoginSuccess => Ok(()),
Response::LoginFailure => Err(std::io::Error::new(io::ErrorKind::Other, "authentication failed").into())
}
}
fn main() {
loop {
match login() {
Ok(()) => {
eprintln!("authentication successful");
break;
}
Err(err) => eprintln!("error: {:?}", err),
}
}
}

View File

@@ -10,7 +10,7 @@ edition = "2018"
nix = "0.15"
pam = { git = "https://github.com/regiontog/pam.git" }
users = "0.9.1"
byteorder = "1.3"
clap = "2.33"
toml = "0.5"
serde = { version = "1.0", features = ["derive"] }
greet_proto = { path = "../greet_proto" }

View File

@@ -1,12 +1,12 @@
use std::error::Error;
use std::io;
use std::io::{Read, Take};
use std::io::{Read, Write, Take};
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use nix::poll::PollFlags;
use byteorder::{LittleEndian, ReadBytesExt};
use greet_proto::{Header, Request, Response};
use crate::context::Context;
use crate::pollable::{PollRunResult, Pollable};
@@ -14,11 +14,9 @@ use crate::scrambler::Scrambler;
enum ClientState {
AwaitingHeader,
AwaitingPayload { typ: u32, len: u32 },
AwaitingPayload { len: u32 },
}
const IPC_HEADERLEN: usize = 16; // 4 bytes magic, 4 bytes version, 4 byte type, 4 byte len
pub struct Client {
stream: Take<UnixStream>,
buf: Vec<u8>,
@@ -26,42 +24,10 @@ pub struct Client {
}
impl Client {
fn read_header(cursor: &mut std::io::Cursor<&[u8]>) -> Result<(u32, u32), Box<dyn Error>> {
let proto_magic = cursor.read_u32::<LittleEndian>()?;
if proto_magic != 0xAFBFCFDF {
return Err(io::Error::new(io::ErrorKind::Other, "invalid message magic").into());
}
let proto_version = cursor.read_u32::<LittleEndian>()?;
if proto_version != 1 {
return Err(io::Error::new(io::ErrorKind::Other, "invalid message version").into());
}
let msg_type = cursor.read_u32::<LittleEndian>()?;
let msg_len = cursor.read_u32::<LittleEndian>()?;
Ok((msg_type, msg_len))
}
fn read_string(cursor: &mut std::io::Cursor<&[u8]>) -> Result<String, Box<dyn Error>> {
let len = cursor.read_u32::<LittleEndian>()?;
let mut data: Vec<u8> = vec![0; len as usize];
cursor.read_exact(&mut data)?;
String::from_utf8(data).map_err(|x| x.into())
}
fn read_login(
cursor: &mut std::io::Cursor<&[u8]>,
) -> Result<(String, String, String), Box<dyn Error>> {
let user = Client::read_string(cursor)?;
let pass = Client::read_string(cursor)?;
let cmd = Client::read_string(cursor)?;
Ok((user, pass, cmd))
}
pub fn new(stream: UnixStream) -> Result<Client, Box<dyn Error>> {
stream.set_nonblocking(true)?;
Ok(Client {
stream: stream.take(IPC_HEADERLEN as u64),
stream: stream.take(Header::len() as u64),
buf: Vec::new(),
state: ClientState::AwaitingHeader,
})
@@ -83,25 +49,21 @@ impl Pollable for Client {
ClientState::AwaitingHeader => {
match self.stream.read_to_end(&mut self.buf) {
Ok(_) => {
if self.buf.len() < IPC_HEADERLEN {
if self.buf.len() < Header::len() {
// Got EOF before we got enough data.
self.buf.scramble();
break Ok(PollRunResult::Dead);
}
let mut rdr = std::io::Cursor::new(self.buf.as_slice());
let (msg_type, msg_len) = match Client::read_header(&mut rdr) {
Ok(v) => v,
Err(_) => {
self.buf.scramble();
break Ok(PollRunResult::Dead);
}
};
let header = Header::from_slice(self.buf.as_slice())?;
if header.version != 1 {
return Err(io::Error::new(io::ErrorKind::Other, "invalid message version").into());
}
self.state = ClientState::AwaitingPayload {
typ: msg_type,
len: msg_len,
len: header.len,
};
self.stream.set_limit(msg_len as u64);
self.stream.set_limit(header.len as u64);
self.buf.truncate(0);
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
@@ -113,7 +75,7 @@ impl Pollable for Client {
}
}
}
ClientState::AwaitingPayload { typ, len } => {
ClientState::AwaitingPayload { len } => {
match self.stream.read_to_end(&mut self.buf) {
Ok(_) => {
if self.buf.len() < len as usize {
@@ -122,31 +84,29 @@ impl Pollable for Client {
break Ok(PollRunResult::Dead);
}
self.state = ClientState::AwaitingHeader;
self.stream.set_limit(IPC_HEADERLEN as u64);
match typ {
1 => {
// Login
let mut rdr = std::io::Cursor::new(self.buf.as_slice());
let (user, pass, cmd) = match Client::read_login(&mut rdr) {
Ok(v) => v,
Err(_) => {
self.buf.scramble();
break Ok(PollRunResult::Dead);
}
};
self.stream.set_limit(Header::len() as u64);
match Request::from_slice(&self.buf)? {
Request::Login{ username, password, command, env } => {
self.buf.scramble();
ctx.login(user, pass, cmd)?;
}
2 => {
// Screen lock
self.buf.scramble();
unimplemented!("screen lock has not yet been implemented");
}
_ => {
// Unknown message type
self.buf.scramble();
break Ok(PollRunResult::Dead);
let resp = match ctx.login(username, password, command, env) {
Ok(_) => Response::LoginSuccess,
Err(_) => Response::LoginFailure,
};
let resp_bytes = resp.to_bytes().expect("unable to serialize response");
let header = Header::new(resp_bytes.len() as u32);
let header_bytes = header.to_bytes().expect("unable to serialize header");
if self.stream.get_mut().write_all(&header_bytes).is_err() {
eprintln!("unable to write response header");
break Ok(PollRunResult::Dead);
}
if self.stream.get_mut().write_all(&resp_bytes).is_err() {
eprintln!("unable to write response");
break Ok(PollRunResult::Dead);
}
}
}
}

View File

@@ -2,20 +2,40 @@ use std::env;
use std::error::Error;
use std::ffi::CString;
use std::io;
use std::collections::HashMap;
use std::time::{Instant, Duration};
use nix::errno::Errno;
use nix::sys::signal::Signal;
use nix::sys::wait::{waitpid, WaitPidFlag, WaitStatus};
use nix::unistd::{execv, fork, initgroups, setgid, setuid, ForkResult, Gid, Uid};
use nix::unistd::{alarm, execv, fork, initgroups, setgid, setuid, ForkResult, Gid, Uid};
use users::os::unix::UserExt;
use crate::scrambler::Scrambler;
/// Session is an active session.
struct Session<'a> {
pam: pam::Authenticator<'a, pam::PasswordConv>,
task: nix::unistd::Pid,
}
/// PendingSession represents a successful login that is pending session
/// startup. It contains all the data necessary to start the session when the
/// greeter has finally shut down.
struct PendingSession<'a> {
waited_since: Instant,
pam: pam::Authenticator<'a, pam::PasswordConv>,
uid: Uid,
gid: Gid,
home: String,
shell: String,
username: String,
env: Vec<(String, String)>,
cmd: Vec<String>,
}
// Greeter is an active greeter.
struct Greeter {
task: nix::unistd::Pid,
}
@@ -23,6 +43,7 @@ struct Greeter {
pub struct Context<'a> {
session: Option<Session<'a>>,
greeter: Option<Greeter>,
pending_session: Option<PendingSession<'a>>,
greeter_bin: String,
greeter_user: String,
@@ -30,9 +51,7 @@ pub struct Context<'a> {
}
fn shoo(task: nix::unistd::Pid) {
eprintln!("sending SIGTERM");
let _ = nix::sys::signal::kill(task, Signal::SIGTERM);
eprintln!("waitpid with exponential backoff to 1 second");
let mut dead = false;
let mut sleep = 1;
while !dead && sleep < 1000 {
@@ -48,10 +67,8 @@ fn shoo(task: nix::unistd::Pid) {
std::thread::sleep(std::time::Duration::from_millis(sleep));
}
if !dead {
eprintln!("sending SIGKILL");
sleep = 1;
let _ = nix::sys::signal::kill(task, Signal::SIGKILL);
eprintln!("waitpid with exponential backoff to 1 second");
while !dead && sleep < 1000 {
match waitpid(task, Some(WaitPidFlag::WNOHANG)) {
Ok(WaitStatus::Exited(..)) | Ok(WaitStatus::Signaled(..)) => {
@@ -65,7 +82,6 @@ fn shoo(task: nix::unistd::Pid) {
std::thread::sleep(std::time::Duration::from_millis(sleep));
}
}
eprintln!("done waiting");
}
impl<'a> Context<'a> {
@@ -73,6 +89,7 @@ impl<'a> Context<'a> {
Context {
session: None,
greeter: None,
pending_session: None,
greeter_bin: greeter_bin,
greeter_user: greeter_user,
tty: tty,
@@ -143,9 +160,9 @@ impl<'a> Context<'a> {
&mut self,
username: String,
mut password: String,
cmd: String,
cmd: Vec<String>,
provided_env: HashMap<String, String>,
) -> Result<(), Box<dyn Error>> {
eprintln!("initiating login");
if self.session.is_some() {
return Err(io::Error::new(io::ErrorKind::Other, "session already active").into());
}
@@ -159,61 +176,80 @@ impl<'a> Context<'a> {
return Err(io::Error::new(io::ErrorKind::Other, "authentication failed").into());
}
eprintln!("login successful");
// TODO: Fetch the username from the PAM session.
let u = users::get_user_by_name(&username).expect("unable to get user struct");
let uid = Uid::from_raw(u.uid());
let gid = Gid::from_raw(u.primary_group_id());
let cusername = CString::new(u.name().to_str().expect("unable to get username"))
.expect("unable to create username CString");
let cpath = CString::new("/bin/sh").unwrap();
let cargs = [
cpath.clone(),
CString::new("-c").unwrap(),
CString::new(format!("[ -f /etc/profile ] && source /etc/profile; [ -f $HOME/.profile ] && source $HOME/.profile; exec {}", cmd)).unwrap()
];
auth.env("XDG_SESSION_TYPE", "wayland")?;
auth.env("XDG_SESSION_CLASS", "user")?;
auth.env("XDG_VTNR", self.tty.to_string())?;
auth.env("XDG_SEAT", "seat0")?;
eprintln!("opening session");
for (key, value) in provided_env.iter() {
auth.env(key, value)?;
}
auth.open_session().expect("unable to open session");
password.scramble();
let myenv: Vec<String> = if let Some(pamenv) = auth.environment() {
let myenv: Vec<(String, String)> = if let Some(pamenv) = auth.environment() {
pamenv
.iter()
.map(|x| x.to_string_lossy().into_owned())
.map(|x| {
let x = x.to_string_lossy().into_owned();
let mut parts = x.splitn(2, '=');
match (parts.next(), parts.next()) {
(Some(key), Some(value)) => Some((key.to_string(), value.to_string())),
_ => None,
}
})
.filter(|x| x.is_some())
.map(|x| x.unwrap())
.collect()
} else {
env::vars().map(|(x, y)| format!("{}={}", x, y)).collect()
// TODO: Handle this better. Can it happen at all?
env::vars().chain(provided_env.iter().map(|(x, y)| (x.to_string(), y.to_string()))).collect()
};
eprintln!("terminating greeter");
self.pending_session = Some(PendingSession{
waited_since: Instant::now(),
pam: auth,
env: myenv,
uid,
gid,
home: u.home_dir().to_str().unwrap().to_string(),
shell: u.shell().to_str().unwrap().to_string(),
cmd: cmd,
username: u.name().to_str().unwrap().to_string(),
});
match self.greeter.take() {
Some(greeter) => shoo(greeter.task),
None => (),
};
alarm::set(10);
eprintln!("forking session task");
Ok(())
}
fn start_session(&mut self, p: PendingSession<'a>) -> Result<(), Box<dyn Error>> {
let cusername = CString::new(p.username.to_string())
.expect("unable to create username CString");
let cpath = CString::new("/bin/sh").unwrap();
let cargs = [
cpath.clone(),
CString::new("-c").unwrap(),
CString::new(format!("[ -f /etc/profile ] && source /etc/profile; [ -f $HOME/.profile ] && source $HOME/.profile; exec {}", p.cmd.join(" "))).unwrap()
];
let child = match fork()? {
ForkResult::Parent { child, .. } => child,
ForkResult::Child => {
// Drop privileges to target user
initgroups(&cusername, gid).expect("unable to init groups");
setgid(gid).expect("unable to set GID");
setuid(uid).expect("unable to set UID");
initgroups(&cusername, p.gid).expect("unable to init groups");
setgid(p.gid).expect("unable to set GID");
setuid(p.uid).expect("unable to set UID");
// Change working directory
let pwd = match env::set_current_dir(&u.home_dir()) {
Ok(_) => u.home_dir().to_str().unwrap().to_string(),
let pwd = match env::set_current_dir(&p.home) {
Ok(_) => p.home.to_string(),
Err(_) => {
env::set_current_dir("/").expect("unable to set current dir");
"/".to_string()
@@ -221,26 +257,20 @@ impl<'a> Context<'a> {
};;
// Set environment
for e in myenv {
let mut parts = e.splitn(2, '=');
match (parts.next(), parts.next()) {
(Some(key), Some(value)) => env::set_var(key, value),
_ => (),
};
for (key, value) in p.env {
env::set_var(key, value);
}
env::set_var("LOGNAME", &u.name());
env::set_var("HOME", &u.home_dir());
env::set_var("LOGNAME", &p.username);
env::set_var("HOME", &p.home);
env::set_var("PWD", &pwd);
env::set_var("SHELL", &u.shell());
env::set_var("SHELL", &p.shell);
if env::var("TERM").is_err() {
env::set_var("TERM", "linux");
}
if env::var("XDG_RUNTIME_DIR").is_err() {
env::set_var("XDG_RUNTIME_DIR", format!("/run/user/{}", uid));
env::set_var("XDG_RUNTIME_DIR", format!("/run/user/{}", p.uid));
}
eprintln!("execing session task");
// Run
execv(&cpath, &cargs).expect("unable to exec");
unreachable!("after exec");
@@ -249,24 +279,44 @@ impl<'a> Context<'a> {
self.session = Some(Session {
task: child,
pam: auth,
pam: p.pam,
});
Ok(())
}
pub fn alarm(&mut self) {
if let Some(Greeter{ task }) = self.greeter.take() {
if let Some(p) = self.pending_session.take() {
if p.waited_since.elapsed() > Duration::from_secs(5) {
shoo(task);
if let Err(e) = self.start_session(p) {
eprintln!("session start failed: {:?}", e);
}
} else {
self.pending_session = Some(p);
self.greeter = Some(Greeter{ task });
}
} else {
self.greeter = Some(Greeter{ task });
}
}
}
pub fn check_children(&mut self) {
match self.session.take() {
Some(session) => {
match waitpid(session.task, Some(WaitPidFlag::WNOHANG)) {
Ok(WaitStatus::Exited(..)) | Ok(WaitStatus::Signaled(..)) => {
Ok(WaitStatus::Exited(..)) | Ok(WaitStatus::Signaled(..)) | Err(nix::Error::Sys(Errno::ECHILD)) => {
// Session task is dead, so kill the session and
// restart the greeter.
eprintln!("session exited");
drop(session.pam);
self.greet().expect("unable to start greeter");
}
Ok(WaitStatus::StillAlive) => self.session = Some(session),
_ => panic!("waitpid on session returned unexpected status"),
v => panic!("waitpid on session returned unexpected status: {:?}", v),
}
}
None => (),
@@ -274,15 +324,24 @@ impl<'a> Context<'a> {
match self.greeter.take() {
Some(greeter) => {
match waitpid(greeter.task, Some(WaitPidFlag::WNOHANG)) {
Ok(WaitStatus::Exited(..)) | Ok(WaitStatus::Signaled(..)) => {
if self.session.is_none() {
// Greeter died on us, let's just die with it.
eprintln!("greeter exited");
std::process::exit(1);
Ok(WaitStatus::Exited(..)) | Ok(WaitStatus::Signaled(..)) | Err(nix::Error::Sys(Errno::ECHILD)) => {
match self.pending_session.take() {
Some(pending_session) => {
// Our greeter finally bit the dust so we can
// start our pending session.
if let Err(e) = self.start_session(pending_session) {
eprintln!("session start failed: {:?}", e);
}
}
None => if self.session.is_none() {
// Greeter died on us, let's just die with it.
eprintln!("greeter exited");
std::process::exit(1);
}
}
}
Ok(WaitStatus::StillAlive) => self.greeter = Some(greeter),
_ => panic!("waitpid on greeter returned unexpected status"),
v => panic!("waitpid on greeter returned unexpected status: {:?}", v),
}
}
None => (),

View File

@@ -16,8 +16,9 @@ pub struct Signals {
impl Signals {
pub fn new() -> Result<Signals, Box<dyn Error>> {
let mut mask = SigSet::empty();
mask.add(Signal::SIGCHLD);
mask.add(Signal::SIGALRM);
mask.add(Signal::SIGTERM);
mask.add(Signal::SIGCHLD);
mask.thread_block()?;
let listener = SignalFd::with_flags(&mask, SfdFlags::SFD_NONBLOCK | SfdFlags::SFD_CLOEXEC)?;
@@ -39,6 +40,7 @@ impl Pollable for Signals {
loop {
match self.listener.read_signal() {
Ok(Some(sig)) => match Signal::from_c_int(sig.ssi_signo as i32)? {
Signal::SIGALRM => ctx.alarm(),
Signal::SIGCHLD => ctx.check_children(),
Signal::SIGTERM => ctx.terminate(),
_ => (),