package org.apache.hadoop.yarn.server.nodemanager.containermanager.resourceplugin.deviceframework;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.hadoop.shaded.com.google.common.collect.ImmutableMap;
import org.apache.hadoop.shaded.com.google.common.collect.ImmutableSet;
import org.apache.hadoop.shaded.com.google.common.collect.Sets;
import org.apache.hadoop.shaded.com.google.common.collect.UnmodifiableIterator;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.exceptions.ResourceNotFoundException;
import org.apache.hadoop.yarn.server.nodemanager.Context;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.Device;
import org.apache.hadoop.yarn.server.nodemanager.api.deviceplugin.DevicePluginScheduler;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.container.Container;
import org.apache.hadoop.yarn.server.nodemanager.containermanager.linux.resources.ResourceHandlerException;

/* loaded from: input_file:org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager.class */
public class DeviceMappingManager {
    /** Logger for this class. */
    static final Log LOG = LogFactory.getLog(DeviceMappingManager.class);

    /** NodeManager context; used to look up containers and the NM state store. */
    private final Context nmContext;

    /** Polling interval in milliseconds; equals the 1000 ms sleep used while waiting for devices. */
    private static final int WAIT_MS_PER_LOOP = 1000;

    /** Custom scheduler per resource name; a missing entry selects the default scheduling logic. */
    private final Map<String, DevicePluginScheduler> devicePluginSchedulers = new ConcurrentHashMap<>();

    /** All devices registered per resource name. */
    private final Map<String, Set<Device>> allAllowedDevices = new ConcurrentHashMap<>();

    /** Per resource name, the devices currently assigned and the container that holds each one. */
    private final Map<String, Map<Device, ContainerId>> allUsedDevices = new ConcurrentHashMap<>();

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:org/apache/hadoop/yarn/server/nodemanager/containermanager/resourceplugin/deviceframework/DeviceMappingManager$DeviceAllocation.class */
    public static class DeviceAllocation {
        /** Resource name this allocation belongs to. */
        private final String resourceName;
        /** Immutable set of devices the container may use; never null. */
        private final Set<Device> allowed;
        /** Immutable set of devices the container must not use; never null. */
        private final Set<Device> denied;

        /**
         * Builds an allocation result, taking immutable snapshots of both sets.
         *
         * @param str  resource name
         * @param set  devices granted to the container; {@code null} is treated as empty
         * @param set2 devices denied to the container; {@code null} is treated as empty
         */
        DeviceAllocation(String str, Set<Device> set, Set<Device> set2) {
            this.resourceName = str;
            this.allowed = (set == null) ? Collections.<Device>emptySet() : ImmutableSet.copyOf(set);
            this.denied = (set2 == null) ? Collections.<Device>emptySet() : ImmutableSet.copyOf(set2);
        }

        /** @return the devices granted to the container (possibly empty, never null) */
        public Set<Device> getAllowed() {
            return allowed;
        }

        /** @return the devices denied to the container (possibly empty, never null) */
        public Set<Device> getDenied() {
            return denied;
        }

        /** @return a one-line summary of the resource name and both device sets */
        @Override
        public String toString() {
            StringBuilder text = new StringBuilder();
            text.append("ResourceType: ").append(resourceName);
            text.append(", Allowed Devices: ").append(allowed);
            text.append(", Denied Devices: ").append(denied);
            return text.toString();
        }
    }

    /**
     * Creates a manager bound to the given NodeManager context.
     *
     * @param context context later used to look up containers and the NM state store
     */
    public DeviceMappingManager(Context context) {
        nmContext = context;
    }

    /**
     * Exposes the registered devices per resource name.
     *
     * @return the internal map itself (not a copy), keyed by resource name
     */
    @VisibleForTesting
    public Map<String, Set<Device>> getAllAllowedDevices() {
        return allAllowedDevices;
    }

    /**
     * Exposes the device-to-container assignments per resource name.
     *
     * @return the internal map itself (not a copy), keyed by resource name
     */
    @VisibleForTesting
    public Map<String, Map<Device, ContainerId>> getAllUsedDevices() {
        return allUsedDevices;
    }

    /**
     * Exposes the custom schedulers registered per resource name.
     *
     * @return the internal map itself (not a copy), keyed by resource name
     */
    @VisibleForTesting
    public Map<String, DevicePluginScheduler> getDevicePluginSchedulers() {
        return devicePluginSchedulers;
    }

    /**
     * Collects the devices of one resource type currently assigned to a container.
     *
     * @param str         resource name
     * @param containerId container whose devices are wanted
     * @return a new sorted set of the matching devices; empty when the container
     *         holds none or when the resource name has no assignment map
     */
    @VisibleForTesting
    public Set<Device> getAllocatedDevices(String str, ContainerId containerId) {
        Set<Device> owned = new TreeSet<>();
        Map<Device, ContainerId> assignments = getAllUsedDevices().get(str);
        if (assignments == null) {
            // No assignment map for this resource name: nothing can be allocated.
            return owned;
        }
        for (Map.Entry<Device, ContainerId> assignment : assignments.entrySet()) {
            if (assignment.getValue().equals(containerId)) {
                owned.add(assignment.getKey());
            }
        }
        return owned;
    }

    /**
     * Registers (or replaces) the devices of a resource type and starts it with
     * an empty assignment map.
     *
     * @param str resource name
     * @param set devices to register; copied into a sorted set
     */
    public synchronized void addDeviceSet(String str, Set<Device> set) {
        LOG.info("Adding new resource: type:" + str + "," + set);
        Set<Device> sortedDevices = new TreeSet<>(set);
        Map<Device, ContainerId> noAssignments = new TreeMap<>();
        allAllowedDevices.put(str, sortedDevices);
        allUsedDevices.put(str, noAssignments);
    }

    /**
     * Assigns devices of the given resource type to a container, polling while
     * an attempt reports (by returning {@code null}) that it should be retried.
     *
     * <p>Attempts are retried once per second for at most 120 seconds. If the
     * waiting thread is interrupted, its interrupt status is restored and the
     * wait ends immediately.
     *
     * @param str       resource name
     * @param container container requesting the devices
     * @return the allocation for the container
     * @throws ResourceHandlerException if no allocation was obtained before the
     *         wait timed out or was interrupted, or if an attempt itself fails
     */
    public DeviceAllocation assignDevices(String str, Container container) throws ResourceHandlerException {
        final long waitMsPerLoop = 1000L;
        final long timeoutMs = 120 * waitMsPerLoop;

        DeviceAllocation allocation = internalAssignDevices(str, container);
        long waitedMs = 0;
        while (allocation == null && waitedMs < timeoutMs) {
            LOG.info("Container : " + container.getContainerId() + " is waiting for free " + str + " devices.");
            try {
                Thread.sleep(waitMsPerLoop);
            } catch (InterruptedException e) {
                // sleep() cleared the interrupt status; restore it so callers can
                // observe the interruption, and stop waiting.
                Thread.currentThread().interrupt();
                LOG.warn("Interrupted while container " + container.getContainerId() + " was waiting for free " + str + " devices.");
                break;
            }
            waitedMs += waitMsPerLoop;
            allocation = internalAssignDevices(str, container);
        }
        if (allocation != null) {
            return allocation;
        }
        String str2 = "Could not get valid " + str + " device for container '" + container.getContainerId() + "' as some other containers might not releasing them.";
        LOG.warn(str2);
        throw new ResourceHandlerException(str2);
    }

    /**
     * Makes a single allocation attempt for the container.
     *
     * @param str       resource name
     * @param container container requesting the devices
     * @return the allocation; an allocation with no allowed devices (all denied)
     *         when the container requests none; or {@code null} when the request
     *         exceeds the free devices but fits once devices held by containers
     *         in a final state are counted
     * @throws ResourceHandlerException if the request cannot be met even after
     *         counting releasing devices, or if persisting the assignment fails
     */
    private synchronized DeviceAllocation internalAssignDevices(String str, Container container) throws ResourceHandlerException {
        Resource requested = container.getResource();
        ContainerId requestor = container.getContainerId();
        int wanted = getRequestedDeviceCount(str, requested);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Try allocating " + wanted + " " + str);
        }
        if (wanted <= 0) {
            // Nothing requested: deny every registered device of this type.
            return new DeviceAllocation(str, null, this.allAllowedDevices.get(str));
        }

        int free = getAvailableDevices(str);
        if (wanted > free) {
            if (wanted <= getReleasingDevices(str) + free) {
                // Enough devices will exist once finished containers release theirs.
                return null;
            }
            throw new ResourceHandlerException("Failed to find enough " + str + ", requestor=" + requestor + ", #Requested=" + wanted + ", #available=" + free);
        }

        Set<Device> granted = new TreeSet<>();
        Map<Device, ContainerId> inUse = this.allUsedDevices.get(str);
        Set<Device> registered = this.allAllowedDevices.get(str);
        pickAndDoSchedule(registered, inUse, granted, container, wanted, str, this.devicePluginSchedulers.get(str));

        if (!granted.isEmpty()) {
            try {
                this.nmContext.getNMStateStore().storeAssignedResources(container, str, new ArrayList<>(granted));
            } catch (IOException storeFailure) {
                // Roll back the in-memory assignment when it could not be persisted.
                cleanupAssignedDevices(str, requestor);
                throw new ResourceHandlerException(storeFailure);
            }
        }
        return new DeviceAllocation(str, granted, Sets.difference(registered, granted));
    }

    /**
     * Re-establishes the in-memory device assignments of a container from the
     * resources recorded in its resource mappings.
     *
     * @param str         resource name
     * @param containerId container whose assignments are recovered
     * @throws ResourceHandlerException if the container is unknown, a recorded
     *         resource is not a {@link Device}, the device is not registered, or
     *         the device is already assigned
     */
    public synchronized void recoverAssignedDevices(String str, ContainerId containerId) throws ResourceHandlerException {
        Container recovering = this.nmContext.getContainers().get(containerId);
        Map<Device, ContainerId> inUse = this.allUsedDevices.get(str);
        Set<Device> registered = this.allAllowedDevices.get(str);
        if (recovering == null) {
            throw new ResourceHandlerException("This shouldn't happen, cannot find container with id=" + containerId);
        }
        for (Serializable stored : recovering.getResourceMappings().getAssignedResources(str)) {
            if (stored instanceof Device) {
                Device device = (Device) stored;
                boolean known = registered.contains(device);
                if (!known) {
                    throw new ResourceHandlerException("Try to recover device = " + device + " however it is not in allowed device list:" + StringUtils.join(",", registered));
                }
                if (inUse.containsKey(device)) {
                    throw new ResourceHandlerException("Try to recover device id = " + device + " however it is already assigned to container=" + inUse.get(device) + ", please double check what happened.");
                }
                inUse.put(device, containerId);
            } else {
                throw new ResourceHandlerException("Trying to recover device id, however it is not Device instance, this shouldn't happen");
            }
        }
    }

    /**
     * Releases every device of the given resource type held by a container.
     *
     * @param str         resource name
     * @param containerId container whose devices are released
     */
    public synchronized void cleanupAssignedDevices(String str, ContainerId containerId) {
        Map<Device, ContainerId> inUse = this.allUsedDevices.get(str);
        for (Iterator<Map.Entry<Device, ContainerId>> cursor = inUse.entrySet().iterator(); cursor.hasNext();) {
            Map.Entry<Device, ContainerId> assignment = cursor.next();
            if (!assignment.getValue().equals(containerId)) {
                continue;
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug("Recycle devices: " + assignment.getKey() + ", type: " + str + " from " + containerId);
            }
            // Remove through the iterator so the map can be modified mid-iteration.
            cursor.remove();
        }
    }

    /**
     * Reads how many units of a resource type a {@link Resource} requests.
     *
     * <p>The value is clamped to the {@code int} range rather than truncated,
     * so an out-of-range request cannot wrap around to a small or negative count.
     *
     * @param str      resource name
     * @param resource resource to read from
     * @return the requested count, or {@code 0} when the resource does not
     *         contain the given resource name
     */
    public static int getRequestedDeviceCount(String str, Resource resource) {
        final long requested;
        try {
            requested = resource.getResourceValue(str);
        } catch (ResourceNotFoundException e) {
            return 0;
        }
        // A plain (int) cast would keep only the low 32 bits of the value.
        return (int) Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, requested));
    }

    /**
     * Counts the devices of a resource type that are not currently assigned.
     *
     * @param str resource name
     * @return number of registered devices minus number of assigned devices
     */
    public int getAvailableDevices(String str) {
        int registeredCount = this.allAllowedDevices.get(str).size();
        int assignedCount = this.allUsedDevices.get(str).size();
        return registeredCount - assignedCount;
    }

    /**
     * Sums the requested amount of a resource type over the distinct containers
     * that hold devices of that type and are in a final state.
     *
     * @param str resource name
     * @return total resource value requested by those containers
     */
    private long getReleasingDevices(String str) {
        // De-duplicate owners so a container holding several devices counts once.
        ImmutableSet<ContainerId> owners = ImmutableSet.copyOf(this.allUsedDevices.get(str).values());
        long releasing = 0;
        for (ContainerId owner : owners) {
            Container holder = this.nmContext.getContainers().get(owner);
            if (holder == null || !holder.isContainerInFinalStates()) {
                continue;
            }
            releasing += holder.getResource().getResourceInformation(str).getValue();
        }
        return releasing;
    }

    /**
     * Chooses devices for a container, using the custom scheduler when one is
     * supplied and the default scheduling logic otherwise, and records the
     * chosen devices as used by the container.
     *
     * @param set                   all registered devices of the resource type
     * @param map                   current device-to-container assignments; updated in place
     * @param set2                  receives the devices chosen for the container
     * @param container             container requesting the devices
     * @param i                     number of devices to allocate
     * @param str                   resource name, used in log and error messages
     * @param devicePluginScheduler custom scheduler, or {@code null} for the default logic
     * @throws ResourceHandlerException if the custom scheduler returns {@code null}
     *         or a number of devices different from {@code i}
     */
    private void pickAndDoSchedule(Set<Device> set, Map<Device, ContainerId> map, Set<Device> set2, Container container, int i, String str, DevicePluginScheduler devicePluginScheduler) throws ResourceHandlerException {
        ContainerId containerId = container.getContainerId();
        if (null == devicePluginScheduler) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Customized device plugin scheduler is preferred but not implemented, use default logic");
            }
            defaultScheduleAction(set, map, set2, containerId, i);
            return;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("Customized device plugin implemented,use customized logic");
            LOG.debug("Try to schedule " + i + "(" + str + ") using " + devicePluginScheduler.getClass());
        }
        // Only the custom scheduler consumes the environment, so it is read here
        // rather than before the default path, which never uses it.
        Map<String, String> environment = container.getLaunchContext().getEnvironment();
        Set<Device> allocateDevices = devicePluginScheduler.allocateDevices(Sets.difference(set, map.keySet()), i, ImmutableMap.copyOf(environment));
        int actual = (allocateDevices == null) ? 0 : allocateDevices.size();
        if (allocateDevices == null || actual != i) {
            // Report what the plugin actually returned, not the still-empty result set.
            throw new ResourceHandlerException(devicePluginScheduler.getClass() + " should allocate " + i + " of " + str + ", but actual: " + actual);
        }
        set2.addAll(allocateDevices);
        for (Device device : set2) {
            map.put(device, containerId);
        }
    }

    /**
     * Default scheduling: walks the registered devices in iteration order and
     * claims the first unassigned ones until the requested count is reached.
     *
     * @param set         all registered devices of the resource type
     * @param map         current device-to-container assignments; updated in place
     * @param set2        receives the devices chosen for the container
     * @param containerId container the devices are assigned to
     * @param i           number of devices to allocate; non-positive allocates nothing
     */
    private void defaultScheduleAction(Set<Device> set, Map<Device, ContainerId> map, Set<Device> set2, ContainerId containerId, int i) {
        // Guarded like the other debug statements in this class, so the collections
        // are not stringified on every allocation when debug logging is off.
        if (LOG.isDebugEnabled()) {
            LOG.debug("Using default scheduler. Allowed:" + set + ",Used:" + map + ", containerId:" + containerId);
        }
        if (i <= 0) {
            // Without this guard the size check below never matches and every
            // free device would be claimed.
            return;
        }
        for (Device device : set) {
            if (map.containsKey(device)) {
                continue;
            }
            map.put(device, containerId);
            set2.add(device);
            if (set2.size() == i) {
                return;
            }
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Registers a custom scheduler for a resource type, replacing any previous one.
     *
     * @param str                   resource name
     * @param devicePluginScheduler scheduler to register; must not be {@code null}
     * @throws NullPointerException if {@code devicePluginScheduler} is {@code null}
     */
    @VisibleForTesting
    public synchronized void addDevicePluginScheduler(String str, DevicePluginScheduler devicePluginScheduler) {
        DevicePluginScheduler checked = Objects.requireNonNull(devicePluginScheduler);
        this.devicePluginSchedulers.put(str, checked);
    }
}
