diff --git a/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml b/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml index 2c5b5f56fc..31b87feb31 100644 --- a/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml +++ b/lib/compiler/include/compiler/task_graph_simulator/pcg_task_graph.dtg.toml @@ -7,19 +7,19 @@ features = [ includes = [ "utils/graph/digraph/digraph_view.h", - "utils/bidict/bidict.h", "compiler/task_graph_simulator/pcg_task.dtg.h", "pcg/device_id_t.dtg.h", + "utils/many_to_one/many_to_one.h", "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", - "", - "" + "", + "" ] src_includes = [ - "utils/fmt/unordered_set.h", - "utils/hash/unordered_set.h", - "utils/fmt/unordered_map.h", - "utils/hash/unordered_map.h" + "utils/fmt/set.h", + "utils/hash/set.h", + "utils/fmt/map.h", + "utils/hash/map.h" ] [[fields]] @@ -28,8 +28,8 @@ type = "::FlexFlow::DiGraphView" [[fields]] name = "node_to_task" -type = "::FlexFlow::bidict<::FlexFlow::Node, ::FlexFlow::PCGTask>" +type = "::FlexFlow::ManyToOne<::FlexFlow::Node, ::FlexFlow::PCGTask>" [[fields]] name = "node_to_devices" -type = "std::unordered_map<::FlexFlow::Node, std::unordered_set<::FlexFlow::device_id_t>>" +type = "std::map<::FlexFlow::Node, std::set<::FlexFlow::device_id_t>>" diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc index 1aeb83d202..e8bd602289 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.cc @@ -1,13 +1,13 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_communication_edge.h" #include "utils/containers/filtermap_keys.h" -#include "utils/containers/map_from_pairs.h" #include "utils/containers/map_keys_with_value_merging.h" -#include "utils/containers/merge_maps_with.h" #include "utils/containers/require_all_same1.h" #include "utils/containers/require_same.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" +#include "utils/containers/merge_unordered_maps_with.h" +#include "utils/containers/unordered_map_from_pairs.h" namespace FlexFlow { @@ -15,7 +15,7 @@ std::unordered_set abstracted_single_tensor_movement_get_dst_layers( AbstractedSingleTensorMovement const &m) { return transform( - keys(m.edge_to_size), + unordered_keys(m.edge_to_size), [](AbstractedSingleTensorCommunicationEdge const &e) -> BinaryTreePath { return e.dst.operator_tree_path; }); @@ -34,7 +34,7 @@ AbstractedSingleTensorMovement merge_abstracted_single_tensor_movements( return AbstractedSingleTensorMovement{ /*src_op_tree_path=*/require_all_same1(src_paths), /*edge_to_size=*/ - merge_maps_with(transform(vector_of(movements), + merge_unordered_maps_with(transform(vector_of(movements), [](AbstractedSingleTensorMovement const &m) { return m.edge_to_size; }), @@ -51,7 +51,7 @@ AbstractedSingleTensorMovement return AbstractedSingleTensorMovement{ /*src_op_tree_path=*/src_op_tree_path, /*edge_to_size=*/ - map_from_pairs( + unordered_map_from_pairs( transform(communications, [](AbstractedSingleTensorCommunication const &c) { return std::pair{c.edge, c.size}; diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc index 98a7d9b0b2..37bf62029f 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_tensor_set_movement.cc @@ -3,13 +3,13 @@ #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.dtg.h" #include "compiler/machine_mapping/abstracted_tensor_set_movement/abstracted_single_tensor_movement.h" #include "compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.h" -#include "utils/containers/binary_merge_maps_with.h" #include "utils/containers/flatmap.h" #include "utils/containers/map_keys_with_value_merging.h" #include "utils/containers/merge_maps_with.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_set_of.h" #include "utils/hash/unordered_map.h" +#include "utils/containers/binary_merge_unordered_maps_with.h" namespace FlexFlow { @@ -63,7 +63,7 @@ TensorSetMovement concretize_abstracted_tensor_set_movement( [](TensorSetMovement const &lhs, TensorSetMovement const &rhs) -> TensorSetMovement { return TensorSetMovement{ - binary_merge_maps_with( + binary_merge_unordered_maps_with( lhs.edge_to_size, rhs.edge_to_size, [](num_bytes_t l, num_bytes_t r) { return l + r; }), diff --git a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc index 151008f65f..192ada3fb6 100644 --- a/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/abstracted_tensor_set_movement/get_abstracted_tensor_set_movement_across_split.cc @@ -10,10 +10,9 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" -#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" #include "utils/containers/binary_cartesian_product.h" #include "utils/containers/flatmap.h" -#include "utils/containers/generate_map.h" #include "utils/containers/get_only.h" #include "utils/containers/group_by.h" #include "utils/containers/map_from_pairs.h" @@ -45,8 +44,8 @@ AbstractedSingleTensorMovement get_abstracted_single_tensor_movement_along_edge( op_to_op_get_coord_mapping(mapping); std::unordered_map - single_comms = map_from_pairs(transform( - unordered_set_of(coord_mapping), + single_comms = unordered_map_from_pairs(transform( + bidict_unordered_set_of(coord_mapping), [&](std::pair const & src_dst) -> std::pair { @@ -101,9 +100,9 @@ AbstractedTensorSetMovement get_abstracted_tensor_set_movement_across_split( }; return AbstractedTensorSetMovement{ - transform(edges_by_tensor.right_groups(), - [&](nonempty_unordered_set const - &edges) { + transform(unordered_set_of(edges_by_tensor.right_groups()), + [&](nonempty_set const &edges) + { return merge_abstracted_single_tensor_movements(transform( unordered_multiset_of(edges.unwrap_as_unordered_set()), to_abstracted_single_tensor_movement)); diff --git a/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc index 7ccab2fac9..4e38750de3 100644 --- a/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/apply_substitution_and_update_machine_mapping.cc @@ -59,7 +59,7 @@ SearchResult apply_substitution_and_update_machine_mapping( select_random(substituted_machine_views)); } - ASSERT(is_subseteq_of(keys(post_node_data), keys(machine_views))); + ASSERT(is_subseteq_of(unordered_keys(post_node_data), unordered_keys(machine_views))); std::unordered_map post_node_machine_views = diff --git a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc index 77e50740aa..48f3bd9eed 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_optimal_machine_mapping.cc @@ -23,7 +23,7 @@ #include "utils/containers/contains.h" #include "utils/containers/contains_key.h" #include "utils/containers/flatmap.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/keys.h" #include "utils/containers/set_minus.h" @@ -103,7 +103,7 @@ MachineMappingResult set_minus(boundary_layers, get_constrained_layers(sub_constraints)); std::unordered_map> - allowed = generate_map( + allowed = generate_unordered_map( unconstrained_boundary_layers, [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedRuntimeOnlyOpCostEstimateKey leaf = diff --git a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc index 7d1f28337c..f7dbdd1d05 100644 --- a/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc +++ b/lib/compiler/src/compiler/machine_mapping/get_tensor_set_movement_across_split.cc @@ -7,7 +7,6 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" -#include "utils/containers/generate_map.h" #include "utils/containers/keys.h" #include "utils/containers/map_values.h" #include "utils/containers/sum.h" diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc index a2307716ba..c7b068d121 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping.cc @@ -7,8 +7,8 @@ #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/containers/are_disjoint.h" -#include "utils/containers/binary_merge_disjoint_maps.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -20,7 +20,7 @@ MappedParallelComputationGraph get_parallel_layers(pcg); std::unordered_set mapped_layers = - keys(mapping.machine_views); + unordered_keys(mapping.machine_views); ASSERT(mapped_layers == pcg_layers); @@ -40,7 +40,7 @@ MappedParallelComputationGraph }; std::unordered_map - mapped_op_task_groups = generate_map(mapped_layers, mapping_for_layer); + mapped_op_task_groups = generate_unordered_map(mapped_layers, mapping_for_layer); return mapped_pcg_from_pcg_and_mapped_op_task_groups(pcg, mapped_op_task_groups); @@ -49,12 +49,12 @@ MappedParallelComputationGraph MachineMapping combine_disjoint_mappings(MachineMapping const &m1, MachineMapping const &m2) { return MachineMapping{ - binary_merge_disjoint_maps(m1.machine_views, m2.machine_views), + binary_merge_disjoint_unordered_maps(m1.machine_views, m2.machine_views), }; } bool nodes_are_disjoint(MachineMapping const &m1, MachineMapping const &m2) { - return are_disjoint(keys(m1.machine_views), keys(m2.machine_views)); + return are_disjoint(unordered_keys(m1.machine_views), unordered_keys(m2.machine_views)); } std::optional get_machine_mapping_from_machine_mapping_result( diff --git a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc index 8278d5511c..f77d424795 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_mapping_constraints.cc @@ -3,8 +3,8 @@ #include "utils/containers/filter_values.h" #include "utils/containers/filtermap_keys.h" #include "utils/containers/flatmap.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_values.h" #include "utils/containers/restrict_keys.h" #include "utils/full_binary_tree/binary_tree_path.h" @@ -14,7 +14,7 @@ namespace FlexFlow { MachineMappingConstraints get_unconstrained_solution_for_layers( std::unordered_set const &layers) { return MachineMappingConstraints{ - generate_map(layers, + generate_unordered_map(layers, [](BinaryTreePath const &) -> std::optional { return std::nullopt; }), @@ -24,7 +24,7 @@ MachineMappingConstraints get_unconstrained_solution_for_layers( std::unordered_set get_unconstrained_layers(MachineMappingConstraints const &constraints) { - return keys(filter_values( + return unordered_keys(filter_values( constraints.machine_views, [](std::optional const &mv) { return !mv.has_value(); })); } @@ -32,14 +32,14 @@ std::unordered_set std::unordered_set get_constrained_layers(MachineMappingConstraints const &constraints) { - return keys(filter_values( + return unordered_keys(filter_values( constraints.machine_views, [](std::optional const &mv) { return mv.has_value(); })); } std::unordered_set get_all_layers(MachineMappingConstraints const &partial_solution) { - return keys(partial_solution.machine_views); + return unordered_keys(partial_solution.machine_views); } std::optional get_machine_view_for_layer( @@ -103,7 +103,7 @@ MachineMappingConstraints with_additional_constraints( std::optional require_only_root(MachineMappingConstraints const &constraints) { - ASSERT(keys(constraints.machine_views) == + ASSERT(unordered_keys(constraints.machine_views) == std::unordered_set{binary_tree_root_path()}, fmt::format("require_only_root expected constraints to have only a " "single key (the root path), but received {}", diff --git a/lib/compiler/src/compiler/machine_mapping/machine_view.cc b/lib/compiler/src/compiler/machine_mapping/machine_view.cc index 090dec5845..7d00707b0d 100644 --- a/lib/compiler/src/compiler/machine_mapping/machine_view.cc +++ b/lib/compiler/src/compiler/machine_mapping/machine_view.cc @@ -221,8 +221,8 @@ static OperatorAtomicTaskShardBinding mappings = get_operator_to_ptensor_mappings(op_attrs, inputs_dim_degrees); std::unordered_map - ptensor_coords = generate_map( - keys(inputs_dim_degrees), + ptensor_coords = generate_unordered_map( + unordered_keys(inputs_dim_degrees), [&](TensorSlotName const &slot_name) -> ParallelTensorSpaceCoordinate { num_ptensor_shard_dims_t num_shard_dims = diff --git a/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc b/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc index a3f2009a60..dad91dd317 100644 --- a/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc +++ b/lib/compiler/src/compiler/machine_mapping/memory_optimization/get_optimal_machine_mapping_with_memory.cc @@ -17,7 +17,7 @@ #include "pcg/parallel_computation_graph/parallel_computation_graph.h" #include "utils/containers/contains.h" #include "utils/containers/flatmap.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/unordered_set_of.h" #include "utils/exception.h" @@ -84,7 +84,7 @@ MachineMappingWithMemoryResult get_optimal_machine_mapping_with_memory( std::unordered_set const &boundary_layers) -> std::unordered_set { std::unordered_map> - allowed = generate_map( + allowed = generate_unordered_map( boundary_layers, [&](BinaryTreePath const &l) -> std::unordered_set { UnmappedRuntimeOnlyOpCostEstimateKey leaf = diff --git a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc index 6e2096afcc..97814288ab 100644 --- a/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc +++ b/lib/compiler/src/compiler/machine_mapping/parallel_layer_guid_oblivious_machine_mapping.cc @@ -5,11 +5,11 @@ #include "op-attrs/get_operator_task_space.h" #include "op-attrs/parallel_tensor_shape.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/map_keys.h" #include "utils/containers/require_same.h" #include "utils/containers/try_at.h" #include "utils/full_binary_tree/binary_tree_path.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -17,7 +17,7 @@ ParallelLayerGuidObliviousMachineMapping binary_combine_mappings( ParallelLayerGuidObliviousMachineMapping const &lhs, ParallelLayerGuidObliviousMachineMapping const &rhs) { return ParallelLayerGuidObliviousMachineMapping{ - binary_merge_disjoint_maps( + binary_merge_disjoint_unordered_maps( map_keys(lhs.raw_mapping, nest_inside_left_child), map_keys(rhs.raw_mapping, nest_inside_right_child)), }; @@ -45,7 +45,7 @@ std::unordered_map PCGBinarySPDecomposition const &decomposition, ParallelLayerGuidObliviousMachineMapping const &mapping) { std::unordered_set leaf_paths = require_same( - pcg_sp_tree_get_all_leaf_paths(decomposition), keys(mapping.raw_mapping)); + pcg_sp_tree_get_all_leaf_paths(decomposition), unordered_keys(mapping.raw_mapping)); std::unordered_map path_to_op_task_space_map = @@ -54,7 +54,7 @@ std::unordered_map return get_operator_task_space(pcg, l); }); - return generate_map( + return generate_unordered_map( leaf_paths, [&](BinaryTreePath const &p) -> MachineSpaceStencil { return MachineSpaceStencil{ /*operator_task_space=*/path_to_op_task_space_map.at(p), @@ -71,12 +71,12 @@ std::unordered_map> std::unordered_map tree_leaf_map = mm_problem_tree_get_path_to_leaf_map(tree); - std::unordered_set mapping_paths = keys(mapping.raw_mapping); - std::unordered_set tree_paths = keys(tree_leaf_map); + std::unordered_set mapping_paths = unordered_keys(mapping.raw_mapping); + std::unordered_set tree_paths = unordered_keys(tree_leaf_map); ASSERT(is_subseteq_of(mapping_paths, tree_paths)); - return generate_map( + return generate_unordered_map( tree_paths, [&](BinaryTreePath const &p) -> std::optional { if (!contains_key(mapping.raw_mapping, p)) { diff --git a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc index cd8e634f2c..4d1c88d9eb 100644 --- a/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc +++ b/lib/compiler/src/compiler/series_parallel/pcg/pcg_binary_sp_decomposition.cc @@ -152,7 +152,7 @@ SPDecompositionTreeNodeType std::unordered_set pcg_sp_tree_get_all_leaf_paths(PCGBinarySPDecomposition const &tree) { - return keys(pcg_sp_tree_get_path_to_leaf_map(tree)); + return unordered_keys(pcg_sp_tree_get_path_to_leaf_map(tree)); } std::unordered_set diff --git a/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc b/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc index d4d5a78d6a..b016b106e9 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/pcg_task_graph.cc @@ -15,6 +15,7 @@ #include "utils/graph/instances/adjacency_digraph.h" #include #include +#include "utils/containers/set_of.h" namespace FlexFlow { @@ -23,21 +24,21 @@ PCGTaskGraph MachineMapping const &machine_mapping, MachineComputeSpecification const &machine_spec) { DiGraph digraph = DiGraph::create(); - bidict node_to_task; + ManyToOne node_to_task; bidict node_to_layer; - std::unordered_map> node_to_devices; + std::map> node_to_devices; for (parallel_layer_guid_t const &layer : get_parallel_layers(pcg)) { MachineView mv = machine_mapping.machine_views.at(layer); RuntimeOnlyOpCostEstimateKey op_key = get_mapped_runtime_only_op_cost_estimate_key_for_layer(pcg, layer, mv); Node node = digraph.add_node(); - node_to_task.equate(node, PCGTask{op_key}); - node_to_layer.equate(node, layer); + node_to_task.insert({node, PCGTask{op_key}}); + node_to_layer.equate_strict(node, layer); node_to_devices[node] = - get_device_ids(get_operator_task_space(pcg, layer), - machine_mapping.machine_views.at(layer), - machine_spec); + set_of(get_device_ids(get_operator_task_space(pcg, layer), + machine_mapping.machine_views.at(layer), + machine_spec)); } for (ParallelComputationGraphEdge const &edge : get_edges(pcg)) { @@ -46,7 +47,7 @@ PCGTaskGraph TensorSetMovement movement = get_tensor_set_movement_from_pcg_edge(edge, pcg, src_mv, dst_mv); Node node = digraph.add_node(); - node_to_task.equate(node, PCGTask{movement}); + node_to_task.insert({node, PCGTask{movement}}); node_to_devices[node] = {}; Node src_node = node_to_layer.at_r(get_src_layer(edge)); Node dst_node = node_to_layer.at_r(get_dst_layer(edge)); diff --git a/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc b/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc index bc528493a8..3fabfc3966 100644 --- a/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc +++ b/lib/compiler/src/compiler/task_graph_simulator/task_simulator.cc @@ -62,7 +62,7 @@ milliseconds_t task_simulator_estimate_forward_pass_time( std::unordered_set devices_occupied = set_union(transform(in_progress_tasks, get_devices)); - std::unordered_set required_devices = get_devices(task); + std::unordered_set required_devices = unordered_set_of(get_devices(task)); return intersection(devices_occupied, required_devices).empty(); }; diff --git a/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc b/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc index be8c7c4f98..05a049e98f 100644 --- a/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc +++ b/lib/compiler/src/compiler/unity_algorithm/unity_algorithm.cc @@ -20,7 +20,7 @@ #include "substitutions/sub_parallel_computation_graph.h" #include "substitutions/substitution.h" #include "substitutions/unity_substitution_set.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/deduplicated_priority_queue.h" #include "utils/graph/node/algorithms.h" #include "utils/optional.h" diff --git a/lib/kernels/include/kernels/accessor.h b/lib/kernels/include/kernels/accessor.h index 27d3693e7e..cacc56046f 100644 --- a/lib/kernels/include/kernels/accessor.h +++ b/lib/kernels/include/kernels/accessor.h @@ -43,6 +43,11 @@ class GenericTensorAccessorR { bool operator==(GenericTensorAccessorR const &) const; bool operator!=(GenericTensorAccessorR const &) const; + bool operator<(GenericTensorAccessorR const &) const; + bool operator<=(GenericTensorAccessorR const &) const; + bool operator>(GenericTensorAccessorR const &) const; + bool operator>=(GenericTensorAccessorR const &) const; + template real_type_t
const &at(TensorDimsCoord const &indices) const { ASSERT(this->device_type == DeviceType::CPU, @@ -97,6 +102,11 @@ class GenericTensorAccessorW { bool operator==(GenericTensorAccessorW const &) const; bool operator!=(GenericTensorAccessorW const &) const; + bool operator<(GenericTensorAccessorW const &) const; + bool operator<=(GenericTensorAccessorW const &) const; + bool operator>(GenericTensorAccessorW const &) const; + bool operator>=(GenericTensorAccessorW const &) const; + operator GenericTensorAccessorR() const; template diff --git a/lib/kernels/src/kernels/accessor.cc b/lib/kernels/src/kernels/accessor.cc index a3f8ead17f..75f144a57a 100644 --- a/lib/kernels/src/kernels/accessor.cc +++ b/lib/kernels/src/kernels/accessor.cc @@ -98,6 +98,22 @@ bool GenericTensorAccessorW::operator!=( return this->tie() != other.tie(); } +bool GenericTensorAccessorW::operator<(GenericTensorAccessorW const &other) const { + return this->tie() < other.tie(); +} + +bool GenericTensorAccessorW::operator<=(GenericTensorAccessorW const &other) const { + return this->tie() <= other.tie(); +} + +bool GenericTensorAccessorW::operator>(GenericTensorAccessorW const &other) const { + return this->tie() > other.tie(); +} + +bool GenericTensorAccessorW::operator>=(GenericTensorAccessorW const &other) const { + return this->tie() >= other.tie(); +} + int32_t *GenericTensorAccessorW::get_int32_ptr() const { return this->get(); } @@ -150,6 +166,22 @@ bool GenericTensorAccessorR::operator!=( return this->tie() != other.tie(); } +bool GenericTensorAccessorR::operator<(GenericTensorAccessorR const &other) const { + return this->tie() < other.tie(); +} + +bool GenericTensorAccessorR::operator<=(GenericTensorAccessorR const &other) const { + return this->tie() <= other.tie(); +} + +bool GenericTensorAccessorR::operator>(GenericTensorAccessorR const &other) const { + return this->tie() > other.tie(); +} + +bool GenericTensorAccessorR::operator>=(GenericTensorAccessorR const &other) const { + return this->tie() >= other.tie(); +} + int32_t const *GenericTensorAccessorR::get_int32_ptr() const { return this->get(); } diff --git a/lib/local-execution/src/local-execution/task_execution.cc b/lib/local-execution/src/local-execution/task_execution.cc index c96c834d4a..68a4c579f8 100644 --- a/lib/local-execution/src/local-execution/task_execution.cc +++ b/lib/local-execution/src/local-execution/task_execution.cc @@ -14,6 +14,7 @@ #include "utils/optional.h" #include "utils/overload.h" #include +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { @@ -58,9 +59,9 @@ TaskArgumentAccessor make_task_argument_accessor_for_invocation( return assert_unwrap(value.accessor); }; std::unordered_map - tensor_slots_backing = binary_merge_disjoint_maps( + tensor_slots_backing = unordered_map_from_map(binary_merge_disjoint_maps( map_keys_and_values(invocation.inputs, make_param, get_accessor), - map_keys_and_values(invocation.outputs, make_param, get_accessor)); + map_keys_and_values(invocation.outputs, make_param, get_accessor))); return TaskArgumentAccessor::create( /*allocator=*/allocator, diff --git a/lib/local-execution/src/local-execution/tensor_allocation.cc b/lib/local-execution/src/local-execution/tensor_allocation.cc index bb2a1ba2a4..203345a9af 100644 --- a/lib/local-execution/src/local-execution/tensor_allocation.cc +++ b/lib/local-execution/src/local-execution/tensor_allocation.cc @@ -55,7 +55,7 @@ DynamicOpenDataflowGraph perform_tensor_allocation( Allocator &allocator) { ASSERT(no_tensors_are_allocated(g)); ASSERT(tensors_are_ready_for_allocation(g)); - for (DynamicValueAttrs const &v : keys(preallocated)) { + for (DynamicValueAttrs const &v : unordered_keys(preallocated)) { ASSERT(v.accessor == std::nullopt); } diff --git a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml index 88a65f75c5..f2dd7c9350 100644 --- a/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml +++ b/lib/op-attrs/include/op-attrs/pcg_operator_attrs.dtg.toml @@ -11,13 +11,13 @@ features = [ ] includes = [ - "op-attrs/ops/attention_attrs.dtg.h", - "op-attrs/ops/batch_matmul_attrs.dtg.h", - "op-attrs/ops/batch_norm_attrs.dtg.h", - "op-attrs/ops/broadcast_attrs.dtg.h", - "op-attrs/ops/cast_attrs.dtg.h", - "op-attrs/ops/combine_attrs.dtg.h", - "op-attrs/ops/concat_attrs.dtg.h", + "op-attrs/ops/attention_attrs.dtg.h", + "op-attrs/ops/batch_matmul_attrs.dtg.h", + "op-attrs/ops/batch_norm_attrs.dtg.h", + "op-attrs/ops/broadcast_attrs.dtg.h", + "op-attrs/ops/cast_attrs.dtg.h", + "op-attrs/ops/combine_attrs.dtg.h", + "op-attrs/ops/concat_attrs.dtg.h", "op-attrs/ops/conv_2d_attrs.dtg.h", "op-attrs/ops/dropout_attrs.dtg.h", "op-attrs/ops/element_binary_attrs.dtg.h", @@ -61,7 +61,7 @@ key = "cast" [[values]] type = "::FlexFlow::CombineAttrs" -key = "combine_distributed" +key = "parallel_combine" [[values]] type = "::FlexFlow::ConcatAttrs" @@ -125,15 +125,15 @@ key = "reduce" [[values]] type = "::FlexFlow::ReductionAttrs" -key = "reduce_distributed" +key = "parallel_reduce" [[values]] type = "::FlexFlow::RepartitionAttrs" -key = "partition_distributed" +key = "parallel_partition" [[values]] type = "::FlexFlow::ReplicateAttrs" -key = "replicate_distributed" +key = "parallel_replicate" [[values]] type = "::FlexFlow::ReverseAttrs" diff --git a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc index eec9ae869c..89936d9b00 100644 --- a/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc +++ b/lib/op-attrs/src/op-attrs/get_incoming_tensor_roles.cc @@ -46,7 +46,7 @@ std::unordered_map }; }, [&](ConcatAttrs const &) { - return generate_map(get_variadic_inputs_slot_name_sequence(), + return generate_unordered_map(get_variadic_inputs_slot_name_sequence(), [](TensorSlotName) -> IncomingTensorRole { return IncomingTensorRole::INPUT; }); diff --git a/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc b/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc index 618eb533ff..1a97f8b38b 100644 --- a/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc +++ b/lib/op-attrs/src/op-attrs/get_operator_space_to_parallel_tensor_space_mappings.cc @@ -8,11 +8,11 @@ #include "op-attrs/ops/weight.h" #include "utils/containers/filtrans.h" #include "utils/containers/get_only.h" -#include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/require_only_key.h" #include "utils/containers/require_two_keys.h" #include "utils/containers/zip_values_strict.h" #include "utils/overload.h" +#include "utils/containers/merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -281,7 +281,7 @@ std::unordered_map ComputationGraphOpAttrs const &attrs, std::unordered_map const &inputs_degrees) { - return merge_disjoint_maps(std::vector{ + return merge_disjoint_unordered_maps(std::vector{ get_operator_to_input_mappings(attrs, inputs_degrees), get_operator_to_weight_mappings(attrs, inputs_degrees), get_operator_to_output_mappings(attrs, inputs_degrees), diff --git a/lib/op-attrs/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/src/op-attrs/ops/element_unary.cc index 9d02923689..ca7e417814 100644 --- a/lib/op-attrs/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/src/op-attrs/ops/element_unary.cc @@ -35,7 +35,6 @@ ParallelTensorDimDegrees get_output_parallel_dim_degrees( ElementUnaryAttrs const &attrs, ParallelTensorDimDegrees const &input_degrees) { ASSERT(input_degrees.sum_degree.value == 1); - ASSERT(input_degrees.discard_copy_degree.value == 1); return input_degrees; } diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc index 51d7968033..5b5d0b514f 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_dim_degrees.cc @@ -5,10 +5,9 @@ #include "op-attrs/parallel_tensor_dim_idx_t.dtg.h" #include "op-attrs/parallel_tensor_dim_idx_t.h" #include "op-attrs/parallel_tensor_space_coordinate.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filtermap_keys.h" #include "utils/containers/filtrans.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/map_keys.h" #include "utils/containers/map_values.h" @@ -19,6 +18,7 @@ #include "utils/nonnegative_int/nonnegative_range.h" #include "utils/nonnegative_int/num_elements.h" #include "utils/orthotope/minimal_dim_domain.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -96,11 +96,11 @@ std::unordered_map }; std::unordered_map shard_dim_degrees = - generate_map(get_idxs(degrees.shard_degrees), [&](ff_dim_t const &dim) { + generate_unordered_map(get_idxs(degrees.shard_degrees), [&](ff_dim_t const &dim) { return degrees.shard_degrees.at(dim); }); - return binary_merge_disjoint_maps( + return binary_merge_disjoint_unordered_maps( /*lhs=*/replica_dim_degrees, /*rhs=*/map_keys(shard_dim_degrees, [](ff_dim_t const &dim) { return parallel_tensor_dim_idx_t{dim}; @@ -131,7 +131,7 @@ DimDomain ParallelTensorDimDegrees const &dim_degrees) { return DimDomain{ - generate_map(get_parallel_tensor_dim_indices(dim_degrees), + generate_unordered_map(get_parallel_tensor_dim_indices(dim_degrees), [&](parallel_tensor_dim_idx_t idx) { return get_degree_for_parallel_tensor_dim_idx(dim_degrees, idx); diff --git a/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc index 0c6e157697..79d765b02c 100644 --- a/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc +++ b/lib/op-attrs/src/op-attrs/parallel_tensor_space_coordinate.cc @@ -3,7 +3,7 @@ #include "op-attrs/parallel_tensor_dim_idx_t.h" #include "utils/containers/contains_key.h" #include "utils/containers/filtermap_keys.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/unordered_set_of.h" #include "utils/nonnegative_int/num_elements.h" @@ -73,7 +73,7 @@ DimCoord dim_coord_from_parallel_tensor_space_coord( ParallelTensorSpaceCoordinate const &coord) { return DimCoord{ - generate_map(get_dim_idxs_in_ptensor_space_coord(coord), + generate_unordered_map(get_dim_idxs_in_ptensor_space_coord(coord), [&](parallel_tensor_dim_idx_t idx) { return ptensor_coord_component_for_ptensor_dim_idx(coord, idx); diff --git a/lib/op-attrs/src/op-attrs/shape_inference.cc b/lib/op-attrs/src/op-attrs/shape_inference.cc index a3f8066dee..38e116bfd8 100644 --- a/lib/op-attrs/src/op-attrs/shape_inference.cc +++ b/lib/op-attrs/src/op-attrs/shape_inference.cc @@ -52,7 +52,7 @@ static std::vector std::vector expected_slots = slice(slots, 0, v_num_slots.unwrap_nonnegative()); - ASSERT(unordered_set_of(expected_slots) == keys(v)); + ASSERT(unordered_set_of(expected_slots) == unordered_keys(v)); return transform(expected_slots, [&](TensorSlotName const &slot_name) { return v.at(slot_name); diff --git a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc index 672b160cbd..09e49a123c 100644 --- a/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc +++ b/lib/op-attrs/test/src/op-attrs/ops/element_unary.cc @@ -53,22 +53,26 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(result == correct); } - SUBCASE("sum degree > 1") { + SUBCASE("discard copy degree > 1") { positive_int degree = 2_p; - CHECK_THROWS(get_output_shape( - attrs, - make_input( - SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p))); + ParallelTensorShape par_input = + make_input(SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p); + + tl::expected result = + get_output_shape(attrs, par_input); + tl::expected correct = par_input; + + CHECK(result == correct); } - SUBCASE("discard copy degree > 1") { + SUBCASE("sum degree > 1") { positive_int degree = 2_p; CHECK_THROWS(get_output_shape( attrs, make_input( - SumDegree{1_p}, DiscardCopyDegree{degree}, 1_p, 1_p, 1_p))); + SumDegree{degree}, DiscardCopyDegree{1_p}, 1_p, 1_p, 1_p))); } } } diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h index aded1eb657..41aca802e7 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h @@ -7,6 +7,7 @@ #include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" #include "utils/bidict/bidict.h" #include +#include "utils/one_to_many/one_to_many.h" namespace FlexFlow { @@ -38,10 +39,13 @@ struct MappedOperatorTaskGroup { friend struct ::std::hash; }; -bidict +OneToMany get_tensor_bindings_for_slot_name(MappedOperatorTaskGroup const &, TensorSlotName const &); +std::set get_slot_names_for_task_group(MappedOperatorTaskGroup const &); + + nlohmann::json mapped_operator_task_group_as_dot_json(MappedOperatorTaskGroup const &); diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h index 12c7921282..2e789e14c9 100644 --- a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h @@ -3,20 +3,56 @@ #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.h" namespace FlexFlow { std::unordered_set mpcg_get_parallel_layers(MappedParallelComputationGraph const &); + +std::set + mpcg_get_invocation_set(MappedParallelComputationGraph const &); + MappedOperatorTaskGroup mpcg_get_mapping_for_layer(MappedParallelComputationGraph const &, parallel_layer_guid_t); ParallelComputationGraph pcg_from_mpcg(MappedParallelComputationGraph const &); +parallel_layer_guid_t + mpcg_get_source_layer(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +PCGOperatorAttrs mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +ParallelTensorAttrs + mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + +std::unordered_map + mpcg_get_incoming_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +std::unordered_set + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +ManyToOne + mpcg_get_incoming_tensors(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + +bidict + mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &, + parallel_layer_guid_t const &); + std::unordered_set mpcg_get_edges(MappedParallelComputationGraph const &); +std::unordered_set + mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &, + parallel_tensor_guid_t const &); + MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_info.dtg.toml b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_info.dtg.toml new file mode 100644 index 0000000000..056cbecdde --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_info.dtg.toml @@ -0,0 +1,31 @@ +namespace = "FlexFlow" +name = "MappedParallelLayerInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", + "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", +] + +src_includes = [ +] + +[[fields]] +name = "guid" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "attrs" +type = "::FlexFlow::ParallelLayerAttrs" + +[[fields]] +name = "mapping" +type = "::FlexFlow::MappedOperatorTaskGroup" diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.toml b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.toml new file mode 100644 index 0000000000..b28884d551 --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "MappedParallelLayerInvocationInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_info.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_info.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/fmt/map.h", + "utils/hash/map.h", +] + +[[fields]] +name = "incoming" +type = "std::map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorInfo>" + +[[fields]] +name = "layer_info" +type = "::FlexFlow::MappedParallelLayerInfo" + +[[fields]] +name = "outgoing" +type = "std::map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorInfo>" diff --git a/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h new file mode 100644 index 0000000000..dcda3a977b --- /dev/null +++ b/lib/pcg/include/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_PARALLEL_LAYER_INVOCATION_INFO_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_MAPPED_PARALLEL_COMPUTATION_GRAPH_MAPPED_PARALLEL_LAYER_INVOCATION_INFO_H + +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.h" +#include "pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" + +namespace FlexFlow { + +MappedParallelLayerInvocationInfo + mapped_parallel_layer_invocation_info_from_pcg_invocation_and_mapping( + ParallelLayerInvocationInfo const &, + MappedOperatorTaskGroup const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h index 0368be62bc..7c5a825420 100644 --- a/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_computation_graph.h @@ -10,7 +10,9 @@ #include "pcg/parallel_computation_graph/parallel_layer_added_result.dtg.h" #include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" #include +#include "pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.h" namespace FlexFlow { @@ -37,6 +39,13 @@ ParallelLayerAddedResult OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &layer); +std::set + pcg_get_invocation_info_set(ParallelComputationGraph const &); + +ParallelLayerInvocationInfo + pcg_get_invocation_info_for_layer(ParallelComputationGraph const &, + parallel_layer_guid_t); + std::unordered_set get_pcg_edges_from_layer_to_layer(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &src, @@ -53,6 +62,10 @@ std::unordered_map get_incoming_edges(ParallelComputationGraph const &, parallel_layer_guid_t const &); +std::unordered_set + pcg_get_parallel_tensor_uses(ParallelComputationGraph const &, + parallel_tensor_guid_t const &); + std::unordered_set get_initial_layers(ParallelComputationGraph const &); diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_info.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_info.dtg.toml new file mode 100644 index 0000000000..67107cdbb1 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_info.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "ParallelLayerInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h", + "pcg/parallel_computation_graph/parallel_layer_attrs.dtg.h", +] + +src_includes = [ +] + +[[fields]] +name = "guid" +type = "::FlexFlow::parallel_layer_guid_t" + +[[fields]] +name = "attrs" +type = "::FlexFlow::ParallelLayerAttrs" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.toml new file mode 100644 index 0000000000..6555752bcf --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_layer_invocation_info.dtg.toml @@ -0,0 +1,33 @@ +namespace = "FlexFlow" +name = "ParallelLayerInvocationInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_layer_info.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_info.dtg.h", + "op-attrs/tensor_slot_name.dtg.h", +] + +src_includes = [ + "utils/fmt/map.h", + "utils/hash/map.h", +] + +[[fields]] +name = "incoming" +type = "std::map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorInfo>" + +[[fields]] +name = "layer_info" +type = "::FlexFlow::ParallelLayerInfo" + +[[fields]] +name = "outgoing" +type = "std::map<::FlexFlow::TensorSlotName, ::FlexFlow::ParallelTensorInfo>" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_info.dtg.toml b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_info.dtg.toml new file mode 100644 index 0000000000..09f53d6954 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_info.dtg.toml @@ -0,0 +1,26 @@ +namespace = "FlexFlow" +name = "ParallelTensorInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "json", + "fmt", +] + +includes = [ + "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h", + "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h", +] + +src_includes = [ +] + +[[fields]] +name = "guid" +type = "::FlexFlow::parallel_tensor_guid_t" + +[[fields]] +name = "attrs" +type = "::FlexFlow::ParallelTensorAttrs" diff --git a/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h new file mode 100644 index 0000000000..f5e5575632 --- /dev/null +++ b/lib/pcg/include/pcg/parallel_computation_graph/parallel_tensor_use_t.h @@ -0,0 +1,15 @@ +#ifndef _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H +#define _FLEXFLOW_LIB_PCG_INCLUDE_PCG_PARALLEL_COMPUTATION_GRAPH_PARALLEL_TENSOR_USE_T_H + +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.dtg.h" + +namespace FlexFlow { + +parallel_layer_guid_t + parallel_tensor_use_get_layer(parallel_tensor_use_t const &); +TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &); + +} // namespace FlexFlow + +#endif diff --git a/lib/pcg/src/pcg/computation_graph.cc b/lib/pcg/src/pcg/computation_graph.cc index 56bfb98856..1b0b3b3204 100644 --- a/lib/pcg/src/pcg/computation_graph.cc +++ b/lib/pcg/src/pcg/computation_graph.cc @@ -2,7 +2,6 @@ #include "op-attrs/computation_graph_op_attrs.h" #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/shape_inference.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/filter_values.h" #include "utils/containers/filtrans.h" @@ -35,6 +34,7 @@ #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/node/algorithms.h" #include "utils/record_formatter.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -101,7 +101,7 @@ LayerAddedResult add_layer( KwargNodeAddedResult added = computation_graph.raw_graph.add_node( layer_attrs, - binary_merge_disjoint_maps(raw_inputs, raw_weights), + binary_merge_disjoint_unordered_maps(raw_inputs, raw_weights), output_attrs); return LayerAddedResult{ @@ -196,7 +196,7 @@ static std::unordered_map ASSERT(incoming_tensors.size() == incoming_slot_roles.size()); std::unordered_set slots_with_desired_role = - keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { + unordered_keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { return role == desired_role; })); diff --git a/lib/pcg/src/pcg/computation_graph_builder.cc b/lib/pcg/src/pcg/computation_graph_builder.cc index b687aa11b6..2eb140fa58 100644 --- a/lib/pcg/src/pcg/computation_graph_builder.cc +++ b/lib/pcg/src/pcg/computation_graph_builder.cc @@ -48,6 +48,7 @@ #include "utils/fmt/set.h" #include "utils/stack_vector/stack_vector_of.h" #include +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -114,11 +115,11 @@ static void check_incoming_tensor_roles( restrict_keys(get_incoming_tensor_roles(layer.op_attrs), set_union(input_slots, weight_slots)); std::unordered_map current = - binary_merge_disjoint_maps( - generate_map( + binary_merge_disjoint_unordered_maps( + generate_unordered_map( input_slots, [](TensorSlotName) { return IncomingTensorRole::INPUT; }), - generate_map(weight_slots, [](TensorSlotName) { + generate_unordered_map(weight_slots, [](TensorSlotName) { return IncomingTensorRole::WEIGHT; })); @@ -134,8 +135,8 @@ std::unordered_map &weight_initializers, std::optional> const &outputs) { - ASSERT(are_disjoint(keys(inputs), keys(weight_initializers))); - check_incoming_tensor_roles(layer, keys(inputs), keys(weight_initializers)); + ASSERT(are_disjoint(unordered_keys(inputs), unordered_keys(weight_initializers))); + check_incoming_tensor_roles(layer, unordered_keys(inputs), unordered_keys(weight_initializers)); std::unordered_map input_shapes = map_values( inputs, [&](tensor_guid_t const &t) { return this->get_shape(t); }); diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc index d0fd3300f5..3bb508681d 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_operator_task_group.cc @@ -12,6 +12,14 @@ #include "utils/containers/vector_of.h" #include "utils/hash/tuple.h" #include "utils/nonnegative_int/num_elements.h" +#include "utils/containers/require_all_same1.h" +#include "utils/containers/set_of.h" +#include "utils/containers/keys.h" +#include "utils/containers/contains.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/containers/map_values.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/many_to_one/invert_many_to_one.h" namespace FlexFlow { @@ -23,7 +31,7 @@ MappedOperatorTaskGroup::MappedOperatorTaskGroup( transform(vector_of(shard_bindings.right_values()), [&](OperatorAtomicTaskShardBinding const &s) -> std::unordered_set { - return keys(s.tensor_coords); + return unordered_keys(s.tensor_coords); }); std::unordered_set slot_names = @@ -39,8 +47,6 @@ MappedOperatorTaskGroup::MappedOperatorTaskGroup( return ptensor_space_coord_for_slot_name(signature, slot_name); }); - ASSERT(are_all_distinct(coords_for_key)); - std::vector coord_dims_for_key = transform(coords_for_key, [](ParallelTensorSpaceCoordinate const &c) { return ptensor_coord_num_dims(c); @@ -92,15 +98,27 @@ bidict const & return this->shard_bindings; } -bidict +OneToMany get_tensor_bindings_for_slot_name(MappedOperatorTaskGroup const &task_group, TensorSlotName const &slot_name) { - return transform_values(task_group.get_shard_bindings(), - [&](OperatorAtomicTaskShardBinding const &b) { - return ptensor_space_coord_for_slot_name(b, - slot_name); - }) - .reversed(); + std::set slot_names = get_slot_names_for_task_group(task_group); + ASSERT(contains(slot_names, slot_name)); + + std::unordered_map m = + map_values(task_group.get_shard_bindings().as_unordered_map(), + [&](OperatorAtomicTaskShardBinding const &b) -> ParallelTensorSpaceCoordinate { + return ptensor_space_coord_for_slot_name(b, slot_name); + }); + + return invert_many_to_one(many_to_one_from_unstructured_relation(unordered_set_of(m))); +} + +std::set get_slot_names_for_task_group(MappedOperatorTaskGroup const &g) { + return require_all_same1( + transform(vector_of(right_entries(g.get_shard_bindings())), + [&](OperatorAtomicTaskShardBinding const &shard_bindings) -> std::set { + return keys(shard_bindings.tensor_coords); + })); } nlohmann::json diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index f4fa946a66..1bece17c9a 100644 --- a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -2,12 +2,20 @@ #include "op-attrs/pcg_operator_attrs.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "utils/bidict/algorithms/bidict_from_map.h" #include "utils/bidict/algorithms/transform_keys.h" #include "utils/containers/transform.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/materialize_labelled_kwarg_dataflow_graph_view.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/rewrite_labelled_kwarg_dataflow_graph_node_labels.h" +#include "utils/many_to_one/many_to_one_from_map.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h" +#include "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h" +#include "utils/containers/set_of.h" +#include "utils/containers/set_union.h" +#include "utils/containers/keys.h" +#include "utils/containers/require_all_of.h" namespace FlexFlow { @@ -16,6 +24,22 @@ std::unordered_set return get_parallel_layers(pcg_from_mpcg(mpcg)); } +std::set + mpcg_get_invocation_set(MappedParallelComputationGraph const &mpcg) +{ + auto mk_mapped_invocation = [&](ParallelLayerInvocationInfo const &invocation) + -> MappedParallelLayerInvocationInfo + { + MappedOperatorTaskGroup mapping = mpcg_get_mapping_for_layer(mpcg, invocation.layer_info.guid); + + return mapped_parallel_layer_invocation_info_from_pcg_invocation_and_mapping(invocation, mapping); + }; + + ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); + + return transform(pcg_get_invocation_info_set(pcg), mk_mapped_invocation); +} + MappedOperatorTaskGroup mpcg_get_mapping_for_layer(MappedParallelComputationGraph const &mpcg, parallel_layer_guid_t l) { @@ -46,6 +70,59 @@ ParallelComputationGraph }; } +parallel_layer_guid_t + mpcg_get_source_layer(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { + return get_source_layer(pcg_from_mpcg(mpcg), t); +} + +PCGOperatorAttrs + mpcg_get_pcg_op_attrs(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return pcg_get_op_attrs(pcg_from_mpcg(mpcg), l); +} + +ParallelTensorAttrs + mpcg_get_parallel_tensor_attrs(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { + return get_parallel_tensor_attrs(pcg_from_mpcg(mpcg), t); +} + +std::unordered_map + mpcg_get_incoming_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return get_incoming_edges(pcg_from_mpcg(mpcg), l); +} + +std::unordered_set + mpcg_get_outgoing_edges(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return get_outgoing_edges(pcg_from_mpcg(mpcg), l); +} + +ManyToOne + mpcg_get_incoming_tensors(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return many_to_one_from_map(get_incoming_tensors(pcg_from_mpcg(mpcg), l)); +} + +bidict + mpcg_get_outgoing_tensors(MappedParallelComputationGraph const &mpcg, + parallel_layer_guid_t const &l) { + return bidict_from_map(get_outgoing_tensors(pcg_from_mpcg(mpcg), l)); +} + +std::unordered_set + mpcg_get_edges(MappedParallelComputationGraph const &mpcg) { + return get_edges(pcg_from_mpcg(mpcg)); +} + +std::unordered_set + mpcg_get_parallel_tensor_uses(MappedParallelComputationGraph const &mpcg, + parallel_tensor_guid_t const &t) { + return pcg_get_parallel_tensor_uses(pcg_from_mpcg(mpcg), t); +} + MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( ParallelComputationGraph const &pcg, std::unordered_map const @@ -57,6 +134,22 @@ MappedParallelComputationGraph mapped_pcg_from_pcg_and_mapped_op_task_groups( return mapped_op_task_groups.at(l); }; + auto slot_names_for_layer = [&](parallel_layer_guid_t l) -> std::set { + return set_union(keys(get_incoming_tensors(pcg, l)), keys(get_outgoing_tensors(pcg, l))); + }; + + auto slot_names_for_layer_mapping = [&](parallel_layer_guid_t l) -> std::set { + return get_slot_names_for_task_group(mapping_for_layer(l)); + }; + + require_all_of( + get_parallel_layers(pcg), + [&](parallel_layer_guid_t l) -> void { + std::set for_layer = slot_names_for_layer(l); + std::set for_layer_mapping = slot_names_for_layer_mapping(l); + ASSERT(for_layer == for_layer_mapping); + }); + auto mpcg_layer_attrs_from_pcg_layer_attrs = [&](Node const &node, ParallelLayerAttrs const &pcg_layer_attrs) -> MappedParallelLayerAttrs { diff --git a/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.cc b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.cc new file mode 100644 index 0000000000..bda6afb60c --- /dev/null +++ b/lib/pcg/src/pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.cc @@ -0,0 +1,22 @@ +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.h" + +namespace FlexFlow { + +MappedParallelLayerInvocationInfo + mapped_parallel_layer_invocation_info_from_pcg_invocation_and_mapping( + ParallelLayerInvocationInfo const &invocation_info, + MappedOperatorTaskGroup const &mapping) +{ + return MappedParallelLayerInvocationInfo{ + /*incoming=*/invocation_info.incoming, + /*layer_info=*/MappedParallelLayerInfo{ + /*guid=*/invocation_info.layer_info.guid, + /*attrs=*/invocation_info.layer_info.attrs, + /*mapping=*/mapping, + }, + /*outgoing=*/invocation_info.outgoing, + }; +} + + +} // namespace FlexFlow diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc index a548ceb65a..40e8cb9e5e 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph.cc @@ -28,6 +28,7 @@ #include "utils/graph/kwarg_dataflow_graph/algorithms/find_isomorphism_between_kwarg_dataflow_graphs.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_edges_from_node_to_node.h" +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_edges_for_node.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" #include "utils/graph/labelled_kwarg_dataflow_graph/algorithms/labelled_kwarg_dataflow_graph_view_as_dot.h" @@ -36,6 +37,8 @@ #include "utils/graph/node/node.dtg.h" #include "utils/record_formatter.h" #include +#include "utils/containers/map_from_unordered.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -97,7 +100,7 @@ ParallelLayerAddedResult add_parallel_layer( std::unordered_map output_flags = maybe_output_flags.value_or( - generate_map(keys(output_shapes), + generate_unordered_map(unordered_keys(output_shapes), [](TensorSlotName const &) { return CreateGrad::YES; })); std::unordered_map output_attrs = @@ -110,7 +113,7 @@ ParallelLayerAddedResult add_parallel_layer( KwargNodeAddedResult op_added = pcg.raw_graph.add_node( layer_attrs, - binary_merge_disjoint_maps(unwrapped_inputs, unwrapped_weights), + binary_merge_disjoint_unordered_maps(unwrapped_inputs, unwrapped_weights), output_attrs); return ParallelLayerAddedResult{ @@ -163,6 +166,46 @@ OperatorTaskSpace get_operator_task_space(ParallelComputationGraph const &pcg, compgraph_op_attrs_from_pcg_op_attrs(op_attrs).value(), input_degrees); } +std::set + pcg_get_invocation_info_set(ParallelComputationGraph const &pcg) +{ + return transform(set_of(get_parallel_layers(pcg)), + [&](parallel_layer_guid_t l) -> ParallelLayerInvocationInfo { + return pcg_get_invocation_info_for_layer(pcg, l); + }); +} + +ParallelLayerInvocationInfo + pcg_get_invocation_info_for_layer(ParallelComputationGraph const &pcg, + parallel_layer_guid_t l) +{ + ParallelLayerAttrs l_attrs = get_parallel_layer_attrs(pcg, l); + + std::map incoming = + map_from_unordered(get_incoming_tensors(pcg, l)); + + std::map outgoing = + map_from_unordered(get_outgoing_tensors(pcg, l)); + + auto get_parallel_tensor_info = [&](parallel_tensor_guid_t t) -> ParallelTensorInfo { + ParallelTensorAttrs t_attrs = get_parallel_tensor_attrs(pcg, t); + + return ParallelTensorInfo{ + /*guid=*/t, + /*attrs=*/t_attrs, + }; + }; + + return ParallelLayerInvocationInfo{ + /*incoming=*/map_values(incoming, get_parallel_tensor_info), + /*layer_info=*/ParallelLayerInfo{ + /*guid=*/l, + /*attrs=*/l_attrs, + }, + /*outgoing=*/map_values(outgoing, get_parallel_tensor_info), + }; +} + std::unordered_set get_edges(ParallelComputationGraph const &pcg) { return transform(get_all_kwarg_dataflow_edges(pcg.raw_graph), @@ -187,9 +230,10 @@ std::unordered_set get_outgoing_edges(ParallelComputationGraph const &pcg, parallel_layer_guid_t const &l) { std::unordered_set> raw_edges = + unordered_set_of( get_outgoing_kwarg_dataflow_edges_for_node(pcg.raw_graph, l.raw_graph_node) - .right_values(); + .right_values()); return transform(raw_edges, [](KwargDataflowEdge const &e) { return ParallelComputationGraphEdge{e}; }); @@ -206,6 +250,17 @@ std::unordered_map }); } +std::unordered_set + pcg_get_parallel_tensor_uses(ParallelComputationGraph const &pcg, + parallel_tensor_guid_t const &t) { + std::unordered_set> raw_uses = + get_kwarg_dataflow_value_uses(pcg.raw_graph, t.raw_graph_output); + + return transform(raw_uses, [](KwargDataflowInput const &i) { + return parallel_tensor_use_t{i}; + }); +} + std::unordered_set get_initial_layers(ParallelComputationGraph const &pcg) { std::unordered_set raw_sources = get_initial_nodes(pcg.raw_graph); @@ -296,7 +351,7 @@ static std::unordered_map ASSERT(incoming_tensors.size() == incoming_slot_roles.size()); std::unordered_set slots_with_desired_role = - keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { + unordered_keys(filter_values(incoming_slot_roles, [&](IncomingTensorRole role) { return role == desired_role; })); diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index 1d6713dcdb..c01b17c285 100644 --- a/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -34,6 +34,7 @@ #include "utils/containers/transform.h" #include "utils/containers/zip_values_strict_with.h" #include "utils/containers/zip_with.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -678,11 +679,11 @@ static void check_incoming_tensor_roles( std::unordered_map correct = get_incoming_tensor_roles(layer.op_attrs); std::unordered_map current = - binary_merge_disjoint_maps( - generate_map( + binary_merge_disjoint_unordered_maps( + generate_unordered_map( input_slots, [](TensorSlotName) { return IncomingTensorRole::INPUT; }), - generate_map(weight_slots, [](TensorSlotName) { + generate_unordered_map(weight_slots, [](TensorSlotName) { return IncomingTensorRole::WEIGHT; })); @@ -698,8 +699,8 @@ std::unordered_map std::unordered_map const &weight_initializers) { - ASSERT(are_disjoint(keys(inputs), keys(weight_initializers))); - check_incoming_tensor_roles(layer, keys(inputs), keys(weight_initializers)); + ASSERT(are_disjoint(unordered_keys(inputs), unordered_keys(weight_initializers))); + check_incoming_tensor_roles(layer, unordered_keys(inputs), unordered_keys(weight_initializers)); std::unordered_map input_shapes = map_values(inputs, [&](parallel_tensor_guid_t const &i) { diff --git a/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc new file mode 100644 index 0000000000..71a9cadf1c --- /dev/null +++ b/lib/pcg/src/pcg/parallel_computation_graph/parallel_tensor_use_t.cc @@ -0,0 +1,14 @@ +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" + +namespace FlexFlow { + +parallel_layer_guid_t + parallel_tensor_use_get_layer(parallel_tensor_use_t const &u) { + return parallel_layer_guid_t{u.raw_dataflow_input.node}; +} + +TensorSlotName parallel_tensor_use_get_slot(parallel_tensor_use_t const &u) { + return u.raw_dataflow_input.slot_name; +} + +} // namespace FlexFlow diff --git a/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc index 7856d89f27..a53c67c336 100644 --- a/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc +++ b/lib/pcg/test/src/pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.cc @@ -79,18 +79,24 @@ TEST_SUITE(FF_TEST_SUITE) { MappedOperatorTaskGroup partition_mapping = MappedOperatorTaskGroup{ bidict{ - {machine_coord(0_n), - OperatorAtomicTaskShardBinding{ - { - {TensorSlotName::OUTPUT, ptensor_coord(0_n)}, - }, - }}, - {machine_coord(1_n), - OperatorAtomicTaskShardBinding{ - { - {TensorSlotName::OUTPUT, ptensor_coord(1_n)}, - }, - }}, + { + machine_coord(0_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::INPUT, ptensor_coord(0_n)}, + {TensorSlotName::OUTPUT, ptensor_coord(0_n)}, + }, + }, + }, + { + machine_coord(1_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::INPUT, ptensor_coord(0_n)}, + {TensorSlotName::OUTPUT, ptensor_coord(1_n)}, + }, + }, + }, }, }; @@ -116,20 +122,26 @@ TEST_SUITE(FF_TEST_SUITE) { MappedOperatorTaskGroup{ bidict{ - {machine_coord(0_n), - OperatorAtomicTaskShardBinding{ - { - {TensorSlotName::LHS_INPUT, ptensor_coord(0_n)}, - {TensorSlotName::RHS_INPUT, ptensor_coord(0_n)}, - }, - }}, - {machine_coord(1_n), - OperatorAtomicTaskShardBinding{ - { - {TensorSlotName::LHS_INPUT, ptensor_coord(1_n)}, - {TensorSlotName::RHS_INPUT, ptensor_coord(1_n)}, - }, - }}, + { + machine_coord(0_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::LHS_INPUT, ptensor_coord(0_n)}, + {TensorSlotName::RHS_INPUT, ptensor_coord(0_n)}, + {TensorSlotName::OUTPUT, ptensor_coord(0_n)}, + }, + }, + }, + { + machine_coord(1_n), + OperatorAtomicTaskShardBinding{ + { + {TensorSlotName::LHS_INPUT, ptensor_coord(1_n)}, + {TensorSlotName::RHS_INPUT, ptensor_coord(1_n)}, + {TensorSlotName::OUTPUT, ptensor_coord(1_n)}, + }, + }, + }, }, }}, }; diff --git a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc index fd314ebaea..44488e70ea 100644 --- a/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc +++ b/lib/pcg/test/src/pcg/parallel_computation_graph/parallel_computation_graph_builder.cc @@ -5,7 +5,7 @@ #include "pcg/parallel_computation_graph/parallel_layer_attrs.h" #include "pcg/parallel_computation_graph/parallel_tensor_guid_t.h" #include "utils/containers/count.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_only.h" #include "utils/containers/items.h" #include "utils/containers/require_only_key.h" @@ -248,7 +248,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*paddingW=*/paddingW); std::unordered_map layers = - generate_map(get_parallel_layers(b.pcg), + generate_unordered_map(get_parallel_layers(b.pcg), [&](parallel_layer_guid_t const &l) { return get_parallel_layer_attrs(b.pcg, l); }); diff --git a/lib/realm-execution/include/realm-execution/realm_context.h b/lib/realm-execution/include/realm-execution/realm_context.h index ab89e916c0..5b76d52e2c 100644 --- a/lib/realm-execution/include/realm-execution/realm_context.h +++ b/lib/realm-execution/include/realm-execution/realm_context.h @@ -9,6 +9,7 @@ #include "pcg/device_id_t.dtg.h" #include "pcg/machine_space_coordinate.dtg.h" #include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include #include @@ -63,7 +64,7 @@ struct RealmContext { int priority = 0); ///\} - /** \name Data movement */ + /** \name Data movement and reduction */ ///\{ Realm::Event issue_copy(ParallelTensorShape const &src_shape, Realm::RegionInstance src_inst, @@ -72,6 +73,17 @@ struct RealmContext { Realm::ProfilingRequestSet const &requests, Realm::Event wait_on = Realm::Event::NO_EVENT, int priority = 0); + + Realm::Event issue_reduction(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + redop_id_t redop_id, + bool is_fold, + bool exclusive, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on = Realm::Event::NO_EVENT, + int priority = 0); ///\} /** \name Instance management */ diff --git a/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h new file mode 100644 index 0000000000..e7e51326e1 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/realm_redop_registry.h @@ -0,0 +1,16 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_REGISTRY_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_REGISTRY_H + +#include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" + +namespace FlexFlow { + +/** + * \brief Registers all known reduction operators (redops). + */ +void register_all_redops(Realm::Runtime); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml new file mode 100644 index 0000000000..44e1f32c59 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "redop_id_t" +type = "enum" +features = [ + "hash", + "fmt", + "rapidcheck", + "json", +] +docstring = ''' +\brief An enum for identifying reduction operators (redops) for use in the Realm runtime. +''' + +[[values]] +name = "SUM_BOOL_REDOP_ID" + +[[values]] +name = "SUM_INT32_REDOP_ID" + +[[values]] +name = "SUM_INT64_REDOP_ID" + +[[values]] +name = "SUM_FLOAT_REDOP_ID" + +[[values]] +name = "SUM_DOUBLE_REDOP_ID" diff --git a/lib/realm-execution/include/realm-execution/redops/redop_id_t.h b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h new file mode 100644 index 0000000000..8565b20b17 --- /dev/null +++ b/lib/realm-execution/include/realm-execution/redops/redop_id_t.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_ID_T_H +#define _FLEXFLOW_LIB_REALM_EXECUTION_INCLUDE_REALM_EXECUTION_REDOPS_REALM_REDOP_ID_T_H + +#include "op-attrs/datatype.dtg.h" +#include "realm-execution/realm.h" +#include "realm-execution/redops/redop_id_t.dtg.h" + +namespace FlexFlow { + +/** + * \brief Return the sum reduction operator (redop) ID for a given data type. + */ +redop_id_t get_sum_redop_id_for_data_type(DataType); + +/** + * \brief Convert a \ref FlexFlow::redop_id_t into a Realm reduction op ID. + */ +Realm::ReductionOpID get_realm_reduction_op_id_for_redop_id(redop_id_t); + +} // namespace FlexFlow + +#endif diff --git a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml index b1e5e07e28..b0bcc23b4d 100644 --- a/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml +++ b/lib/realm-execution/include/realm-execution/tasks/task_id_t.dtg.toml @@ -327,15 +327,6 @@ name = "COMBINE_FWD_TASK_ID" [[values]] name = "COMBINE_BWD_TASK_ID" -[[values]] -name = "REPLICATE_INIT_TASK_ID" - -[[values]] -name = "REPLICATE_FWD_TASK_ID" - -[[values]] -name = "REPLICATE_BWD_TASK_ID" - [[values]] name = "REDUCTION_INIT_TASK_ID" diff --git a/lib/realm-execution/src/realm-execution/instance_allocation.cc b/lib/realm-execution/src/realm-execution/instance_allocation.cc index 4ef2919b10..79eef476df 100644 --- a/lib/realm-execution/src/realm-execution/instance_allocation.cc +++ b/lib/realm-execution/src/realm-execution/instance_allocation.cc @@ -42,7 +42,7 @@ TensorInstanceBacking perform_instance_allocation( RealmContext &ctx) { ASSERT(no_tensors_are_allocated(g)); ASSERT(tensors_are_ready_for_allocation(g)); - for (DynamicValueAttrs const &v : keys(preallocated)) { + for (DynamicValueAttrs const &v : unordered_keys(preallocated)) { ASSERT(v.accessor == std::nullopt); } diff --git a/lib/realm-execution/src/realm-execution/pcg_instance.cc b/lib/realm-execution/src/realm-execution/pcg_instance.cc index 0ecd02143e..4b068d70be 100644 --- a/lib/realm-execution/src/realm-execution/pcg_instance.cc +++ b/lib/realm-execution/src/realm-execution/pcg_instance.cc @@ -5,6 +5,7 @@ #include "realm-execution/distributed_per_device_op_state_initialization.h" #include "realm-execution/instance_allocation.h" #include "realm-execution/realm_context.h" +#include "realm-execution/redops/redop_id_t.h" #include "realm-execution/tasks/impl/op_task.h" #include "realm-execution/tensor_instance_backing.h" #include "task-spec/dynamic_graph/copy_insertion.h" @@ -215,6 +216,46 @@ static Realm::Event spawn_dynamic_node_invocation( precondition); }; + auto issue_replicate_bwd = [&]() { + DynamicValueAttrs output_grad = get_only(values( + filter_keys(invocation.inputs, [](DynamicTensorSlot const &s) -> bool { + return s.slot_tensor_role == + DynamicTensorRole{FwbTensorType::GRADIENT}; + }))); + + DynamicValueAttrs input_grad = get_only(values(invocation.outputs)); + + Realm::RegionInstance dst_inst = + tensor_instance_backing.backing.at(input_grad).first; + + redop_id_t redop_id = get_sum_redop_id_for_data_type( + assert_unwrap(output_grad.parallel_tensor_shape).data_type); + + // chain reductions sequentially to avoid write races on dst + Realm::Event result = precondition; + for (auto const &[p, m] : unstructured_relation_from_one_to_many(assert_unwrap(output_grad.mapping))) { + DynamicValueAttrs replica_key = output_grad; + replica_key.mapping = + OneToMany{{p, {m}}}; + replica_key.shard_coord = p; + + Realm::RegionInstance src_inst = + tensor_instance_backing.backing.at(replica_key).first; + + result = ctx.issue_reduction( + /*src_shape=*/assert_unwrap(output_grad.parallel_tensor_shape), + /*src_inst=*/src_inst, + /*dst_shape=*/assert_unwrap(input_grad.parallel_tensor_shape), + /*dst_inst=*/dst_inst, + /*redop_id=*/redop_id, + /*is_fold=*/false, + /*exlusive=*/false, + /*requests=*/Realm::ProfilingRequestSet{}, + /*wait_on=*/result); + } + return result; + }; + TrainingOperationAttrs op_attrs = assert_unwrap(invocation.node_attrs.op_attrs); return op_attrs.visit(overload{ @@ -222,6 +263,14 @@ static Realm::Event spawn_dynamic_node_invocation( return pcg_op_attrs.visit(overload{ [&](InputAttrs const &) { return Realm::Event::NO_EVENT; }, [&](WeightAttrs const &) { return Realm::Event::NO_EVENT; }, + [&](ReplicateAttrs const &) { + if (invocation.node_attrs.task_type.has_value() && + invocation.node_attrs.task_type.value() == + DynamicTaskType::BWD) { + return issue_replicate_bwd(); + } + return issue_copy(); // forward + }, [&](auto const &) { return spawn_task(); }, }); }, diff --git a/lib/realm-execution/src/realm-execution/realm_context.cc b/lib/realm-execution/src/realm-execution/realm_context.cc index 790c1bd613..36dd7c71cc 100644 --- a/lib/realm-execution/src/realm-execution/realm_context.cc +++ b/lib/realm-execution/src/realm-execution/realm_context.cc @@ -7,6 +7,7 @@ #include "pcg/device_id_t.h" #include "pcg/device_type.dtg.h" #include "realm-execution/realm_allocator.h" +#include "realm-execution/redops/redop_id_t.h" #include "realm-execution/tasks/task_id_t.dtg.h" #include "realm-execution/tasks/task_id_t.h" #include "utils/containers/contains_key.h" @@ -154,6 +155,46 @@ static Realm::IndexSpace ispace_from_dims(TensorDims const &dims) { return Realm::IndexSpace{rect}; } +[[nodiscard]] static Realm::Event + issue_copy_for_field(TensorDims const &dims, + Realm::CopySrcDstField const &src_field, + Realm::CopySrcDstField const &dst_field, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + switch (dims.ff_ordered.num_dims()) { +#if REALM_MAX_DIM >= 1 + case 1: + return ispace_from_dims<1>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 2 + case 2: + return ispace_from_dims<2>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 3 + case 3: + return ispace_from_dims<3>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 4 + case 4: + return ispace_from_dims<4>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif +#if REALM_MAX_DIM >= 5 + case 5: + return ispace_from_dims<5>(dims).copy( + {src_field}, {dst_field}, requests, wait_on, priority); +#endif + default: + PANIC("TensorShape dims greater than REALM_MAX_DIM: {}", + dims.ff_ordered.num_dims()); + break; + } +} + Realm::Event RealmContext::issue_copy(ParallelTensorShape const &src_shape, Realm::RegionInstance src_inst, @@ -183,43 +224,48 @@ Realm::Event size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), /*subfield_offset=*/0); - Realm::Event result; - switch (src_piece_shape.dims.ff_ordered.num_dims()) { -#if REALM_MAX_DIM >= 1 - case 1: - result = ispace_from_dims<1>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 2 - case 2: - result = ispace_from_dims<2>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 3 - case 3: - result = ispace_from_dims<3>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 4 - case 4: - result = ispace_from_dims<4>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif -#if REALM_MAX_DIM >= 5 - case 5: - result = ispace_from_dims<5>(src_piece_shape.dims) - .copy({src_field}, {dst_field}, requests, wait_on, priority); - break; -#endif - default: - PANIC("TensorShape dims greater than REALM_MAX_DIM: {}", - src_piece_shape.dims.ff_ordered.num_dims()); - break; - } + Realm::Event result = issue_copy_for_field( + src_piece_shape.dims, src_field, dst_field, requests, wait_on, priority); + this->outstanding_events.push_back(result); + return result; +} + +Realm::Event + RealmContext::issue_reduction(ParallelTensorShape const &src_shape, + Realm::RegionInstance src_inst, + ParallelTensorShape const &dst_shape, + Realm::RegionInstance dst_inst, + redop_id_t redop_id, + bool is_fold, + bool exclusive, + Realm::ProfilingRequestSet const &requests, + Realm::Event wait_on, + int priority) { + TensorShape src_piece_shape = get_piece_shape(src_shape); + TensorShape dst_piece_shape = get_piece_shape(dst_shape); + ASSERT(src_piece_shape == dst_piece_shape); // For now, assume they match + + Realm::CopySrcDstField src_field; + src_field.set_field( + /*inst=*/src_inst, + /*field_id=*/0, + /*size=*/ + static_cast( + size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), + /*subfield_offset=*/0); + Realm::CopySrcDstField dst_field; + dst_field.set_field( + /*inst=*/dst_inst, + /*field_id=*/0, + /*size=*/ + static_cast( + size_of_datatype(src_piece_shape.data_type).int_from_positive_int()), + /*subfield_offset=*/0); + dst_field.set_redop( + get_realm_reduction_op_id_for_redop_id(redop_id), is_fold, exclusive); + + Realm::Event result = issue_copy_for_field( + src_piece_shape.dims, src_field, dst_field, requests, wait_on, priority); this->outstanding_events.push_back(result); return result; } diff --git a/lib/realm-execution/src/realm-execution/realm_manager.cc b/lib/realm-execution/src/realm-execution/realm_manager.cc index e76be7054b..5a8f9cbbbb 100644 --- a/lib/realm-execution/src/realm-execution/realm_manager.cc +++ b/lib/realm-execution/src/realm-execution/realm_manager.cc @@ -1,5 +1,6 @@ #include "realm-execution/realm_manager.h" #include "realm-execution/realm_context.h" +#include "realm-execution/redops/realm_redop_registry.h" #include "realm-execution/tasks/realm_task_registry.h" namespace FlexFlow { @@ -9,8 +10,9 @@ RealmManager::RealmManager(int *argc, char ***argv) bool ok = this->get_runtime().init(argc, argv); ASSERT(ok); - // Register all tasks at initialization time so we don't need to later + // Register all tasks and redops at initialization time so we don't need to later register_all_tasks().wait(); + register_all_redops(this->get_runtime()); } RealmManager::~RealmManager() { diff --git a/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc new file mode 100644 index 0000000000..ab3304836a --- /dev/null +++ b/lib/realm-execution/src/realm-execution/redops/realm_redop_registry.cc @@ -0,0 +1,540 @@ +#include "realm-execution/redops/realm_redop_registry.h" +#include "realm-execution/redops/redop_id_t.h" + +namespace FlexFlow { + +// Reduction operators and related infrastructure borrowed from Legion. We +// maintain the Legion naming scheme to maximizing compatibility with the +// existing code, despite not otherwise relying or using Legion in any way. +// https://gitlab.com/StanfordLegion/legion/-/blob/5263aeff477fb94239c50d9306d58c4244e9fc38/runtime/legion/api/redop.inl#L31 +#if !defined(__cpp_lib_atomic_ref) || (__cpp_lib_atomic_ref < 201806L) +// We only need this crap if we're using a version of c++ < 20 +// Starting with c++20 we can do all this the right way with atomic_ref +namespace TypePunning { +// The tenth circle of hell is reserved for members of the C++ committee +// that decided to deviate from C's support for type punning unions. +// Add on to it the fact that it took them 9 fucking years to realize +// that they needed std::atomic_ref and it's plain to see they are all +// just a bunch of idiots that should never be allowed near a programming +// language standard ever again. They've clearly never written lock-free +// code in their lives. +template +class Pointer { +public: + Pointer(void *p) : pointer(convert(p)) {} + static inline T *convert(void *p) { + T *ptr = nullptr; + static_assert(sizeof(ptr) == sizeof(p)); + memcpy(&ptr, &p, sizeof(p)); + return ptr; + } + inline operator T *(void) const { + return (T *)pointer; + } + inline T operator*(void) const { + return *pointer; + } + inline T operator[](size_t off) const { + return pointer[off]; + } + +private: + T volatile *const pointer; +}; +template +class AlignedPointer { +public: + AlignedPointer(void *p) : off(align(p)), pointer(convert(p, off)) {} + static inline T *convert(void *p, size_t off) { + uint8_t *p1 = nullptr; + static_assert(sizeof(p1) == sizeof(p)); + memcpy(&p1, &p, sizeof(p)); + p1 = p1 - off; + T *p2 = nullptr; + static_assert(sizeof(p1) == sizeof(p2)); + memcpy(&p2, &p1, sizeof(p1)); + return p2; + } + static inline size_t align(void *p) { + uintptr_t ptr; + static_assert(sizeof(ptr) == sizeof(p)); + memcpy(&ptr, &p, sizeof(ptr)); + return ptr % ALIGNMENT; + } + inline operator T *(void) const { + return (T *)pointer; + } + inline T operator*(void) const { + return *pointer; + } + inline size_t offset(void) const { + return off; + } + +private: + size_t off; + T volatile *const pointer; +}; +template +class Alias { +public: + inline void load(Pointer const &pointer, size_t off = 0) { + T1 value = pointer[off]; + memcpy(buffer, (void *)&value, sizeof(T1)); + } + template + inline void load(AlignedPointer const &pointer) { + T1 value = *pointer; + memcpy(buffer, (void *)&value, sizeof(T1)); + } + inline T1 as_one(void) const { + T1 result; + memcpy((void *)&result, buffer, sizeof(result)); + return result; + } + inline T2 as_two(void) const { + T2 result; + memcpy((void *)&result, buffer, sizeof(result)); + return result; + } + inline Alias &operator=(T2 rhs) { + memcpy(buffer, (void *)&rhs, sizeof(rhs)); + return *this; + } + +private: + // Make this one private so it is can never be called + inline Alias &operator=(T1 rhs) { + memcpy(buffer, (void *)&rhs, sizeof(rhs)); + return *this; + } + static_assert(sizeof(T1) == sizeof(T2)); + uint8_t buffer[sizeof(T1)]; +}; +}; // namespace TypePunning +#endif + +// Define a prefix for annotating functions for CUDA compilation +#if defined(__CUDACC__) || defined(__HIPCC__) +#define __LEGION_CUDA_HD__ __host__ __device__ +#else +#define __LEGION_CUDA_HD__ +#endif + +template +class SumReduction { + // Empty definition + // Specializations provided for each type +}; + +template <> +class SumReduction { +public: + typedef bool LHS; + typedef bool RHS; + + static constexpr bool identity = false; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef int32_t LHS; + typedef int32_t RHS; + + static constexpr int32_t identity = 0; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef int64_t LHS; + typedef int64_t RHS; + + static constexpr int64_t identity = 0; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef float LHS; + typedef float RHS; + + static constexpr float identity = 0.f; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +class SumReduction { +public: + typedef double LHS; + typedef double RHS; + + static constexpr double identity = 0.0; + + template + __LEGION_CUDA_HD__ static void apply(LHS &lhs, RHS rhs); + template + __LEGION_CUDA_HD__ static void fold(RHS &rhs1, RHS rhs2); +}; + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs = lhs || rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&lhs); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + unsigned int newval = *ptr, oldval; + do { + RHS previous = __uint2bool(newval, offset); + RHS next = previous || rhs; + oldval = newval; + newval = __bool2uint(newval, next, offset); + newval = atomicCAS(ptr, oldval, newval); + } while (oldval != newval); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval || rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic logical operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() || rhs; + } while (!__sync_bool_compare_and_swap( + (int8_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 = rhs1 || rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // GPU atomics need 4 byte alignment + const uintptr_t unaligned = reinterpret_cast(&rhs1); + unsigned const offset = unaligned % sizeof(unsigned int); + const uintptr_t aligned = unaligned - offset; + unsigned int *ptr = reinterpret_cast(aligned); + unsigned int newval = *ptr, oldval; + do { + RHS previous = __uint2bool(newval, offset); + RHS next = previous || rhs2; + oldval = newval; + newval = __bool2uint(newval, next, offset); + newval = atomicCAS(ptr, oldval, newval); + } while (oldval != newval); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval || rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic logical operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() || rhs2; + } while (!__sync_bool_compare_and_swap( + (int8_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else + __sync_fetch_and_add(&lhs, rhs); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else + __sync_fetch_and_add(&rhs1, rhs2); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // Apparently there is no signed 64bit int atomic yet + RHS newval = lhs, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&lhs; + do { + oldval = newval; + newval += rhs; + newval = __ulonglong_as_longlong(atomicCAS( + ptr, __longlong_as_ulonglong(oldval), __longlong_as_ulonglong(newval))); + } while (oldval != newval); +#else + __sync_fetch_and_add(&lhs, rhs); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + // Apparently there is no signed 64bit int atomic yet + RHS newval = rhs1, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&rhs1; + do { + oldval = newval; + newval += rhs2; + newval = __ulonglong_as_longlong(atomicCAS( + ptr, __longlong_as_ulonglong(oldval), __longlong_as_ulonglong(newval))); + } while (oldval != newval); +#else + __sync_fetch_and_add(&rhs1, rhs2); +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs2; + } while (!__sync_bool_compare_and_swap( + (int32_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { + lhs += rhs; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::apply(LHS &lhs, + RHS rhs) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 600) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&lhs, rhs); +#else + RHS newval = lhs, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&lhs; + do { + oldval = newval; + newval += rhs; + newval = __ulonglong_as_double(atomicCAS( + ptr, __double_as_ulonglong(oldval), __double_as_ulonglong(newval))); + } while (oldval != newval); +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(lhs); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&lhs); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs; + } while (!__sync_bool_compare_and_swap( + (int64_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { + rhs1 += rhs2; +} + +template <> +__LEGION_CUDA_HD__ inline void SumReduction::fold(RHS &rhs1, + RHS rhs2) { +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) +#if (__CUDA_ARCH__ >= 600) || defined(__HIP_DEVICE_COMPILE__) + atomicAdd(&rhs1, rhs2); +#else + RHS newval = rhs1, oldval; + // Type punning like this is illegal in C++ but the + // CUDA manual has an example just like it so fuck it + unsigned long long int *ptr = (unsigned long long int *)&rhs1; + do { + oldval = newval; + newval += rhs2; + newval = __ulonglong_as_double(atomicCAS( + ptr, __double_as_ulonglong(oldval), __double_as_ulonglong(newval))); + } while (oldval != newval); +#endif +#else +#if defined(__cpp_lib_atomic_ref) && (__cpp_lib_atomic_ref >= 201806L) + std::atomic_ref atomic(rhs1); + RHS oldval = atomic.load(); + RHS newval; + do { + newval = oldval + rhs2; + } while (!atomic.compare_exchange_weak(oldval, newval)); +#else + // No atomic floating point operations so use compare and swap + TypePunning::Alias oldval, newval; + TypePunning::Pointer pointer((void *)&rhs1); + do { + oldval.load(pointer); + newval = oldval.as_two() + rhs2; + } while (!__sync_bool_compare_and_swap( + (int64_t *)pointer, oldval.as_one(), newval.as_one())); +#endif +#endif +} + +void register_all_redops(Realm::Runtime rt) { + // Registration is synchronous, so no need to capture events here + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_BOOL_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT32_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_INT64_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_FLOAT_REDOP_ID)); + rt.register_reduction>( + get_realm_reduction_op_id_for_redop_id(redop_id_t::SUM_DOUBLE_REDOP_ID)); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc new file mode 100644 index 0000000000..f31769419f --- /dev/null +++ b/lib/realm-execution/src/realm-execution/redops/redop_id_t.cc @@ -0,0 +1,28 @@ +#include "realm-execution/redops/redop_id_t.h" +#include "utils/exception.h" + +namespace FlexFlow { + +redop_id_t get_sum_redop_id_for_data_type(DataType dtype) { + switch (dtype) { + case DataType::BOOL: + return redop_id_t::SUM_BOOL_REDOP_ID; + case DataType::INT32: + return redop_id_t::SUM_INT32_REDOP_ID; + case DataType::INT64: + return redop_id_t::SUM_INT64_REDOP_ID; + case DataType::FLOAT: + return redop_id_t::SUM_FLOAT_REDOP_ID; + case DataType::DOUBLE: + return redop_id_t::SUM_DOUBLE_REDOP_ID; + default: + PANIC("No known sum reduction for data type {}", dtype); + } +} + +Realm::ReductionOpID + get_realm_reduction_op_id_for_redop_id(redop_id_t redop_id) { + return static_cast(redop_id); +} + +} // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc index e7a8948f8d..dfdfe72ce0 100644 --- a/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc +++ b/lib/realm-execution/src/realm-execution/tasks/realm_task_registry.cc @@ -49,7 +49,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_INIT_TASK_ID, task_id_t::REDUCTION_INIT_TASK_ID, task_id_t::REPARTITION_INIT_TASK_ID, - task_id_t::REPLICATE_INIT_TASK_ID, task_id_t::SOFTMAX_INIT_TASK_ID, }; @@ -86,7 +85,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_FWD_TASK_ID, task_id_t::REDUCTION_FWD_TASK_ID, task_id_t::REPARTITION_FWD_TASK_ID, - task_id_t::REPLICATE_FWD_TASK_ID, task_id_t::RESHAPE_FWD_TASK_ID, task_id_t::REVERSE_FWD_TASK_ID, task_id_t::SOFTMAX_FWD_TASK_ID, @@ -115,7 +113,6 @@ Realm::Event register_all_tasks() { task_id_t::REDUCE_BWD_TASK_ID, task_id_t::REDUCTION_BWD_TASK_ID, task_id_t::REPARTITION_BWD_TASK_ID, - task_id_t::REPLICATE_BWD_TASK_ID, task_id_t::RESHAPE_BWD_TASK_ID, task_id_t::REVERSE_BWD_TASK_ID, task_id_t::SOFTMAX_BWD_TASK_ID, diff --git a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc index 79a5176c4f..1d53824c29 100644 --- a/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc +++ b/lib/realm-execution/src/realm-execution/tasks/serializer/serializable_tensor_instance_backing.cc @@ -8,25 +8,29 @@ namespace FlexFlow { SerializableTensorInstanceBacking tensor_instance_backing_to_serializable( TensorInstanceBacking const &backing) { - return SerializableTensorInstanceBacking{/*backing=*/map_keys_and_values( + return SerializableTensorInstanceBacking{ + /*backing=*/map_keys_and_values( backing.backing, dynamic_value_attrs_to_serializable, [](std::pair const &p) { return std::pair{realm_instance_to_serializable(p.first), realm_event_to_serializable(p.second)}; - })}; + }), + }; } TensorInstanceBacking tensor_instance_backing_from_serializable( SerializableTensorInstanceBacking const &backing) { - return TensorInstanceBacking{/*backing=*/map_keys_and_values( + return TensorInstanceBacking{ + /*backing=*/map_keys_and_values( backing.backing, dynamic_value_attrs_from_serializable, [](std::pair const &p) { return std::pair{realm_instance_from_serializable(p.first), realm_event_from_serializable(p.second)}; - })}; + }), + }; } } // namespace FlexFlow diff --git a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc index dd4b0a66ca..0bdc2ca6b5 100644 --- a/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc +++ b/lib/realm-execution/src/realm-execution/tasks/task_id_t.cc @@ -64,9 +64,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_INIT_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_INIT_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return std::nullopt; }, [](ReverseAttrs const &) { return std::nullopt; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_INIT_TASK_ID; }, @@ -115,9 +113,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_FWD_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_FWD_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_FWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_FWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_FWD_TASK_ID; }, @@ -166,9 +162,7 @@ std::optional [](RepartitionAttrs const &attrs) { return task_id_t::REPARTITION_BWD_TASK_ID; }, - [](ReplicateAttrs const &attrs) { - return task_id_t::REPLICATE_BWD_TASK_ID; - }, + [](ReplicateAttrs const &attrs) { return std::nullopt; }, [](ReshapeAttrs const &) { return task_id_t::RESHAPE_BWD_TASK_ID; }, [](ReverseAttrs const &) { return task_id_t::REVERSE_BWD_TASK_ID; }, [](SoftmaxAttrs const &) { return task_id_t::SOFTMAX_BWD_TASK_ID; }, diff --git a/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc new file mode 100644 index 0000000000..6efbb17eb3 --- /dev/null +++ b/lib/realm-execution/test/src/realm-execution/test_op_replicate.cc @@ -0,0 +1,352 @@ +#include "internal/realm_test_utils.h" +#include "kernels/allocation.h" +#include "kernels/compare_tensor_accessors.h" +#include "kernels/copy_tensor_accessor.h" +#include "kernels/format_accessor_contents.h" +#include "kernels/tensor_accessor_reductions.h" +#include "op-attrs/operator_task_space_to_operator_task_space_mapping.h" +#include "op-attrs/ops/element_unary.h" +#include "op-attrs/ops/linear.h" +#include "op-attrs/ops/replicate.h" +#include "op-attrs/parallel_tensor_shape.h" +#include "op-attrs/tensor_shape.dtg.h" +#include "op-attrs/tensor_slot_name.dtg.h" +#include "pcg/device_type.dtg.h" +#include "pcg/machine_space_coordinate.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.dtg.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_builder.h" +#include "pcg/parallel_computation_graph/parallel_layer_guid_t.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_guid_t.dtg.h" +#include "realm-execution/distributed_ff_handle.h" +#include "realm-execution/dynamic_tensor_accessor_from_instance.h" +#include "realm-execution/pcg_instance.h" +#include "realm-execution/realm_context.h" +#include "realm-execution/realm_manager.h" +#include "task-spec/permissions.h" +#include "test/utils/doctest/check_kv.h" +#include "utils/containers/require_only_key.h" +#include + +namespace test { + +using namespace ::FlexFlow; +namespace Realm = ::FlexFlow::Realm; + +template +static ParallelLayerAttrs make_layer_attrs(T const &op_attrs) { + return ParallelLayerAttrs{ + /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + /*name=*/std::nullopt, + }; +}; + +static bool did_loss_decrease(GenericTensorAccessorR const &first_epoch, + GenericTensorAccessorR const &last_epoch, + Allocator &allocator) { + return tensor_accessor_all( + compare_tensor_accessors_le(last_epoch, first_epoch, allocator)); +} + +MappedParallelComputationGraph + make_test_mpcg_for_device_type(DeviceType device_type) { + positive_int batch_size = 10_p; + positive_int data_dim = 16_p; + positive_int hidden_dim = 32_p; + positive_int output_dim = 1_p; + + TensorShape output_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + TensorShape label_tensor_shape = TensorShape{ + TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + TensorShape input_tensor_shape = + TensorShape{TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + ParallelLayerAddedResult inputs_layer = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input = + require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult inputs_layer_2 = + pcg_add_input_layer(pcg, input_tensor_shape); + parallel_tensor_guid_t t_input_2 = + require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + OperatorType::EW_ADD, + DataType::FLOAT, + false, + false, + }; + + ParallelLayerAddedResult add_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(add_attrs), + { + { + TensorSlotName::LHS_INPUT, + t_input, + }, + { + TensorSlotName::RHS_INPUT, + t_input_2, + }, + }, + /*weights=*/{}); + + parallel_tensor_guid_t t_add_1 = + require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + positive_int replicate_degree = 2_p; + ReplicateAttrs repl_attrs = ReplicateAttrs{replicate_degree}; + ParallelLayerAddedResult repl_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(repl_attrs), + { + { + TensorSlotName::INPUT, + t_add_1, + }, + }, + /*weight=*/{}); + + parallel_tensor_guid_t t_repl_1 = + require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + ParallelLayerAddedResult relu_operator_1 = + add_parallel_layer(pcg, + make_layer_attrs(make_relu_attrs()), + /*inputs=*/ + { + { + TensorSlotName::INPUT, + t_repl_1, + }, + }, + /*weights=*/{}); + + parallel_tensor_guid_t t_relu_1 = + require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + MachineSpaceCoordinate cpu0{0_n, 0_n, device_type}; + MachineSpaceCoordinate cpu1{0_n, 1_n, device_type}; + + ParallelTensorSpaceCoordinate tensor_coord0{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_component=*/FFOrdered{0_n}}; + ParallelTensorSpaceCoordinate tensor_coord1{ + /*sum_component=*/0_n, + /*discard_copy_component=*/1_n, + /*shard_component=*/FFOrdered{0_n}}; + + MappedParallelComputationGraph mpcg = + mapped_pcg_from_pcg_and_mapped_op_task_groups( + /*pcg=*/pcg, + /*mapped_op_task_groups=*/{ + { + inputs_layer.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + inputs_layer_2.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + add_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::LHS_INPUT, tensor_coord0}, + {TensorSlotName::RHS_INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + }, + }, + }, + { + repl_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + { + relu_operator_1.parallel_layer, + MappedOperatorTaskGroup{ + { + { + cpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord0}, + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + cpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::INPUT, tensor_coord1}, + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }, + }, + }); + + return mpcg; +} + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (CPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/2_p, /*num_gpus=*/0_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + MappedParallelComputationGraph mpcg = + make_test_mpcg_for_device_type(DeviceType::CPU); + + std::unordered_map + input_tensors; + + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ + SGDOptimizerAttrs{ + /*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001, + }, + }; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} + +TEST_SUITE(FF_CUDA_TEST_SUITE) { + TEST_CASE("RealmBackend e2e Training Replicate Op (GPU Model Parallelism)") { + std::vector fake_args = + make_fake_realm_args(/*num_cpus=*/1_p, /*num_gpus=*/2_n); + int fake_argc = fake_args.size(); + char **fake_argv = fake_args.data(); + + RealmManager manager = RealmManager{&fake_argc, &fake_argv}; + + ControllerTaskResult result = + manager.start_controller([](RealmContext &ctx) { + Allocator allocator = ctx.get_current_device_allocator(); + + MappedParallelComputationGraph mpcg = + make_test_mpcg_for_device_type(DeviceType::GPU); + + OptimizerAttrs optimizer_attrs = OptimizerAttrs{ + SGDOptimizerAttrs{ + /*lr=*/0.001, + /*momentum=*/0.9, + /*nesterov=*/false, + /*weight_decay=*/0.001, + }, + }; + + std::unordered_map + input_tensors; + + DistributedFfHandle device_handle = create_distributed_ff_handle( + ctx, + /*workSpaceSize=*/1024 * 1024, + /*allowTensorOpMathConversion=*/true); + + PCGInstance pcg_instance = create_pcg_instance( + /*ctx=*/ctx, + /*mpcg=*/mpcg, + /*optimizer=*/optimizer_attrs, + /*loss=*/std::nullopt, + /*input_tensors=*/input_tensors, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + + // begin training loop + int num_epochs = 1; + for (int i = 0; i < num_epochs; i++) { + perform_all_passes_for_pcg_instance( + /*instance=*/pcg_instance, + /*profiling_settings=*/ProfilingSettings{0, 0}, + /*device_handle=*/device_handle, + /*iteration_config=*/FFIterationConfig{1_p}); + } + }); + result.wait(); + } +} +} // namespace test diff --git a/lib/runtime/src/parallel_tensor_uses.cc b/lib/runtime/src/parallel_tensor_uses.cc index 444d3a061a..970a9109dc 100644 --- a/lib/runtime/src/parallel_tensor_uses.cc +++ b/lib/runtime/src/parallel_tensor_uses.cc @@ -31,7 +31,7 @@ Op const *ParallelTensorUses::get_owner(ParallelTensor const &tensor) const { } void ParallelTensorUses::remove(Op const &op) { - for (auto const &k : keys(this->uses)) { + for (auto const &k : unordered_keys(this->uses)) { inplace_filter(this->uses.at(k), [&](ParallelTensorUseDescription const &d) { return d.op->op_guid == op.op_guid; diff --git a/lib/runtime/src/tensor_uses.cc b/lib/runtime/src/tensor_uses.cc index ce4672342d..db2cc942d4 100644 --- a/lib/runtime/src/tensor_uses.cc +++ b/lib/runtime/src/tensor_uses.cc @@ -20,7 +20,7 @@ std::vector TensorUses::at(size_t tensor_guid) const { } void TensorUses::remove(Layer const &layer) { - for (auto const &k : keys(this->uses)) { + for (auto const &k : unordered_keys(this->uses)) { inplace_filter(this->uses.at(k), [&](TensorUseDescription const &d) { return d.layer->layer_guid == layer.layer_guid; }); diff --git a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h index cbfe3ab264..2a3dc8bbb8 100644 --- a/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h +++ b/lib/substitutions/include/substitutions/sub_parallel_computation_graph.h @@ -48,8 +48,8 @@ std::unordered_set get_subgraph_outgoing_edges( std::unordered_set const &); std::unordered_set - get_parallel_tensor_uses(SubParallelComputationGraph const &, - open_parallel_tensor_guid_t const &); + get_open_parallel_tensor_uses(SubParallelComputationGraph const &, + open_parallel_tensor_guid_t const &); SubParallelComputationGraphData get_sub_pcg_data(SubParallelComputationGraph const &); diff --git a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc index 6ed2ef563e..b8140440b7 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/apply_substitution.cc @@ -9,11 +9,11 @@ #include "substitutions/sub_parallel_computation_graph_data.dtg.h" #include "substitutions/sub_parallel_computation_graph_data.h" #include "substitutions/sub_parallel_computation_graph_edge.h" -#include "utils/containers/binary_merge_disjoint_maps.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_minus.h" #include "utils/containers/values.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -50,7 +50,7 @@ SubParallelComputationGraph apply_substitution_from_output_result( require_sub_parallel_computation_graph_data_is_valid(pre_data); std::unordered_set pre_nodes = - keys(pre_data.node_data); + unordered_keys(pre_data.node_data); std::unordered_set matched_nodes = unordered_set_of(values(match.node_assignment)); std::unordered_set post_nodes_from_original_graph = @@ -64,7 +64,7 @@ SubParallelComputationGraph apply_substitution_from_output_result( std::unordered_map post_node_data_from_sub = output_graph_data.node_data; - return binary_merge_disjoint_maps(post_node_data_from_orig, + return binary_merge_disjoint_unordered_maps(post_node_data_from_orig, post_node_data_from_sub); }(); @@ -109,9 +109,10 @@ SubParallelComputationGraph apply_substitution_from_output_result( input_parallel_tensor_guid_t output_graph_input = output_expr_to_result_sub_pcg_mapping.input_mapping.at_r( output_expr_input); - std::unordered_set uses = get_parallel_tensor_uses( - substitution_output_graph, - open_parallel_tensor_guid_from_input(output_graph_input)); + std::unordered_set uses = + get_open_parallel_tensor_uses( + substitution_output_graph, + open_parallel_tensor_guid_from_input(output_graph_input)); for (parallel_tensor_use_t const &use : uses) { SubParallelComputationGraphEdge new_edge = subpcg_edge_from_tensor_and_use(base_graph_tensor, use); @@ -167,8 +168,8 @@ SubParallelComputationGraph apply_substitution_from_output_result( std::unordered_map post_value_data_from_sub = output_graph_data.value_data; - return binary_merge_disjoint_maps(post_value_data_from_orig, - post_value_data_from_sub); + return binary_merge_disjoint_unordered_maps(post_value_data_from_orig, + post_value_data_from_sub); }(); SubParallelComputationGraphData post_data = SubParallelComputationGraphData{ diff --git a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc index 2ad5b54a17..4374a951f8 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/output_expr_to_result_sub_pcg_mapping.cc @@ -2,7 +2,7 @@ #include "substitutions/output_graph/output_graph_expr.h" #include "substitutions/sub_parallel_computation_graph.h" #include "utils/bidict/algorithms/bidict_from_pairs.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include "utils/containers/values.h" #include "utils/containers/zip_values_strict.h" @@ -26,7 +26,7 @@ bidict mapping_for_layer = bidict_from_pairs(values( zip_values_strict(layer_outputs, output_graph_expr_outputs))); - result = merge_disjoint_bidicts(result, mapping_for_layer); + result = binary_merge_disjoint_bidicts(result, mapping_for_layer); } return result; diff --git a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc index 8e1c06b9b5..d3ad4ca246 100644 --- a/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc +++ b/lib/substitutions/src/substitutions/apply_substitution/perform_shape_inference.cc @@ -1,7 +1,6 @@ #include "substitutions/apply_substitution/perform_shape_inference.h" #include "op-attrs/get_incoming_tensor_roles.h" #include "op-attrs/shape_inference.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filter_values.h" #include "utils/containers/filtrans.h" #include "utils/containers/is_subseteq_of.h" @@ -20,6 +19,7 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" #include "utils/nonnegative_int/num_elements.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -56,12 +56,12 @@ LabelledOpenKwargDataflowGraphView incoming_tensor_roles = get_incoming_tensor_roles(n_attrs.op_attrs); - ASSERT(is_subseteq_of(keys(incoming_shapes), keys(incoming_tensor_roles))); + ASSERT(is_subseteq_of(unordered_keys(incoming_shapes), unordered_keys(incoming_tensor_roles))); auto incoming_shapes_with_role = [&](IncomingTensorRole role) -> std::unordered_map { std::unordered_set slots_with_desired_role = - keys(filter_values(incoming_tensor_roles, + unordered_keys(filter_values(incoming_tensor_roles, [&](IncomingTensorRole r) { return r == role; })); return restrict_keys(incoming_shapes, slots_with_desired_role); @@ -72,7 +72,7 @@ LabelledOpenKwargDataflowGraphView weight_shapes = incoming_shapes_with_role(IncomingTensorRole::WEIGHT); - ASSERT(binary_merge_disjoint_maps(input_shapes, weight_shapes) == + ASSERT(binary_merge_disjoint_unordered_maps(input_shapes, weight_shapes) == incoming_shapes); std::unordered_map diff --git a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc index 647362ee4d..2755716e44 100644 --- a/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc +++ b/lib/substitutions/src/substitutions/output_graph/output_operator_attrs_assignment.cc @@ -2,9 +2,9 @@ #include "substitutions/operator_pattern/get_attribute_map.h" #include "substitutions/output_graph/materialize_operator_from_attrs_map.h" #include "substitutions/output_graph/output_operator_attribute_expr.h" -#include "utils/containers/binary_merge_maps_with_right_dominating.h" #include "utils/containers/map_values.h" #include "utils/exception.h" +#include "utils/containers/binary_merge_unordered_maps_with_right_dominating.h" namespace FlexFlow { @@ -36,7 +36,7 @@ PCGOperatorAttrs materialize_output_operator_from_attrs_assignment( }); std::unordered_map - joined_attrs_map = binary_merge_maps_with_right_dominating( + joined_attrs_map = binary_merge_unordered_maps_with_right_dominating( template_attrs_map, assignments_attrs_map); return materialize_operator_from_attrs_map(joined_attrs_map); diff --git a/lib/substitutions/src/substitutions/pcg_pattern_match.cc b/lib/substitutions/src/substitutions/pcg_pattern_match.cc index 498fd6c1bf..8a71fe2ad5 100644 --- a/lib/substitutions/src/substitutions/pcg_pattern_match.cc +++ b/lib/substitutions/src/substitutions/pcg_pattern_match.cc @@ -4,8 +4,8 @@ #include "substitutions/unlabelled/unlabelled_graph_pattern.h" #include "utils/bidict/algorithms/bidict_from_keys_and_values.h" #include "utils/bidict/algorithms/bidict_from_map.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include "utils/bidict/algorithms/exhaustive_relational_join.h" -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" #include "utils/bidict/algorithms/transform_values.h" #include "utils/containers/is_subseteq_of.h" #include "utils/containers/map_values.h" @@ -34,7 +34,7 @@ bidict exhaustive_relational_join(pattern_node_outputs.reversed(), matched_layer_output_tensors); - result = merge_disjoint_bidicts(result, mapping); + result = binary_merge_disjoint_bidicts(result, mapping); } return result; @@ -76,7 +76,7 @@ void assert_pcg_pattern_match_is_valid_for_pattern_and_subpcg( std::unordered_set pattern_inputs = get_inputs(pattern); std::unordered_set match_pattern_inputs = - keys(match.input_assignment); + unordered_keys(match.input_assignment); ASSERT(pattern_inputs == match_pattern_inputs); } diff --git a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc index 34b8ae1e96..c0c05ad5b1 100644 --- a/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc +++ b/lib/substitutions/src/substitutions/sub_parallel_computation_graph.cc @@ -131,8 +131,8 @@ std::unordered_set get_subgraph_incoming_edges( } std::unordered_set - get_parallel_tensor_uses(SubParallelComputationGraph const &spcg, - open_parallel_tensor_guid_t const &t) { + get_open_parallel_tensor_uses(SubParallelComputationGraph const &spcg, + open_parallel_tensor_guid_t const &t) { std::unordered_set> raw_uses = get_open_kwarg_dataflow_value_uses(spcg.raw_graph, t.raw_open_dataflow_value); diff --git a/lib/substitutions/src/substitutions/substitution_builder.cc b/lib/substitutions/src/substitutions/substitution_builder.cc index f2860326ab..ffda2291b5 100644 --- a/lib/substitutions/src/substitutions/substitution_builder.cc +++ b/lib/substitutions/src/substitutions/substitution_builder.cc @@ -94,7 +94,7 @@ std::unordered_map node_expr, map_values(inputs, raw_open_kwarg_dataflow_value_from_output_graph_expr_value), - generate_map(output_slots, + generate_unordered_map(output_slots, [](TensorSlotName) { return std::monostate{}; })); return map_values( diff --git a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc index e9087b5718..a982277d22 100644 --- a/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc +++ b/lib/substitutions/src/substitutions/unlabelled/find_pattern_matches.cc @@ -16,9 +16,6 @@ #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_incoming_open_kwarg_dataflow_values_for_node.h" #include "utils/many_to_one/invert_many_to_one.h" #include "utils/many_to_one/many_to_one_from_map.h" -#include "utils/many_to_one/many_to_one_from_unstructured_relation.h" -#include "utils/many_to_one/unstructured_relation_from_many_to_one.h" -#include "utils/one_to_many/unstructured_relation_from_one_to_many.h" #include "utils/overload.h" namespace FlexFlow { @@ -46,7 +43,7 @@ static std::optional return OpenKwargDataflowValue{o}; }); - if (keys(pattern_outputs) != keys(graph_outputs)) { + if (unordered_keys(pattern_outputs) != unordered_keys(graph_outputs)) { return std::nullopt; } @@ -64,7 +61,7 @@ static std::optional graph_node_inputs = get_incoming_open_kwarg_dataflow_values_for_node(graph, graph_node); - if (keys(graph_node_inputs) != keys(pattern_node_inputs)) { + if (unordered_keys(graph_node_inputs) != unordered_keys(pattern_node_inputs)) { return std::nullopt; } diff --git a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc index 703d651070..25d505b1fa 100644 --- a/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc +++ b/lib/substitutions/src/substitutions/unlabelled/pattern_matching.cc @@ -190,7 +190,7 @@ bool unlabelled_pattern_does_match( ASSERT(left_entries(match.node_assignment) == get_pattern_nodes(pattern)); ASSERT( is_subseteq_of(right_entries(match.node_assignment), get_nodes(graph))); - ASSERT(keys(match.input_assignment) == get_pattern_inputs(pattern)); + ASSERT(unordered_keys(match.input_assignment) == get_pattern_inputs(pattern)); ASSERT(is_subseteq_of(matched_by_pattern_inputs, get_all_open_kwarg_dataflow_values(graph))); diff --git a/lib/task-spec/include/task-spec/device_specific.h b/lib/task-spec/include/task-spec/device_specific.h index 2055888b1b..834378f415 100644 --- a/lib/task-spec/include/task-spec/device_specific.h +++ b/lib/task-spec/include/task-spec/device_specific.h @@ -25,6 +25,22 @@ struct DeviceSpecific { return this->tie() != other.tie(); } + bool operator<(DeviceSpecific const &other) const { + return this->tie() < other.tie(); + } + + bool operator<=(DeviceSpecific const &other) const { + return this->tie() <= other.tie(); + } + + bool operator>(DeviceSpecific const &other) const { + return this->tie() > other.tie(); + } + + bool operator>=(DeviceSpecific const &other) const { + return this->tie() >= other.tie(); + } + T const *get(device_id_t curr_device_idx) const { ASSERT(curr_device_idx == this->device_idx); return (T const *)this->ptr.get(); diff --git a/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml b/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml index 4435a472ce..0a32b29b25 100644 --- a/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml +++ b/lib/task-spec/include/task-spec/device_specific_per_device_op_state.dtg.toml @@ -4,6 +4,7 @@ type = "variant" features = [ "eq", "hash", + "ord", "fmt", ] diff --git a/lib/task-spec/include/task-spec/dynamic_graph/copy_insertion.h b/lib/task-spec/include/task-spec/dynamic_graph/copy_insertion.h index a1726c2ae1..7a383ee8eb 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/copy_insertion.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/copy_insertion.h @@ -13,6 +13,11 @@ bool value_is_mapped(DynamicValueAttrs const &); bool no_part_of_graph_is_copy_inserted(DynamicOpenDataflowGraph const &); bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &); +std::unordered_set copies_for_invocation_inputs( + DynamicNodeInvocation const &i, + std::unordered_map const + &unmapped_value_to_mapped_source_value); + std::unordered_set perform_copy_insertion_for_invocation( DynamicNodeInvocation const &i, std::unordered_map const diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml index 8def0ec5fb..5200bfc6a6 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.toml @@ -3,6 +3,7 @@ name = "dynamic_layer_guid_t" type = "variant" features = [ "eq", + "ord", "hash", "fmt", "json", diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml index 73c023fd40..110c963383 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_attrs.dtg.toml @@ -4,6 +4,7 @@ type = "struct" features = [ "eq", "hash", + "ord", "fmt", ] @@ -11,6 +12,7 @@ includes = [ "", "task-spec/dynamic_graph/dynamic_task_type.dtg.h", "pcg/machine_space_coordinate.dtg.h", + "utils/nonempty_set/nonempty_set.h", "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h", "task-spec/dynamic_graph/training_operation_attrs.dtg.h", @@ -26,10 +28,10 @@ name = "task_type" type = "std::optional<::FlexFlow::DynamicTaskType>" [[fields]] -name = "device_coord" -type = "std::optional<::FlexFlow::MachineSpaceCoordinate>" +name = "device_coords" +type = "std::optional<::FlexFlow::nonempty_set<::FlexFlow::MachineSpaceCoordinate>>" docstring = ''' -\note Right now the \c device_coord for a copy node is sort of meaningless +\note Right now the \c device_coords for a copy node is sort of meaningless because we have one controller issuing all copies for the entire graph, no matter where they are. However the intention is this to be the "owner" or "issuer" of the copy, which matters a lot more down the road once we write the diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml index 07060106c0..4d6d27444a 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.dtg.toml @@ -3,6 +3,7 @@ name = "DynamicNodeInvocation" type = "struct" features = [ "eq", + "ord", "fmt", "hash", ] @@ -15,13 +16,13 @@ includes = [ ] src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "utils/hash/map.h", + "utils/fmt/map.h", ] [[fields]] name = "inputs" -type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" [[fields]] name = "node_attrs" @@ -29,4 +30,4 @@ type = "::FlexFlow::DynamicNodeAttrs" [[fields]] name = "outputs" -type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrs>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h new file mode 100644 index 0000000000..a87f2241c4 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation.h @@ -0,0 +1,20 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_NODE_INVOCATION_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_DYNAMIC_NODE_INVOCATION_H + +#include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" + +namespace FlexFlow { + +bool invocation_fully_satisfies(DynamicNodeInvocation const &, + std::function const &node_condition, + std::function const &value_condition, + std::function const &slot_condition); + +void require_invocation_fully_satisfies(DynamicNodeInvocation const &, + std::function const &require_node_condition, + std::function const &require_value_condition, + std::function const &require_slot_condition); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml new file mode 100644 index 0000000000..00d98a2e6c --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.toml @@ -0,0 +1,30 @@ +namespace = "FlexFlow" +name = "DynamicNodeInvocationShardingInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "pcg/machine_space_coordinate.dtg.h", + "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h", + "task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.h", + "utils/nonempty_set/nonempty_set.h", +] + +src_includes = [ + "utils/hash/map.h", + "utils/fmt/map.h", +] + +[[fields]] +name = "device_coords" +type = "::FlexFlow::nonempty_set<::FlexFlow::MachineSpaceCoordinate>" + +[[fields]] +name = "value_sharding" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::DynamicValueAttrsShardingInfo>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h index 4ca62db5b1..1aba00a675 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_open_dataflow_graph.h @@ -23,6 +23,12 @@ bool no_part_of_dynamic_graph_satisfies( std::function const &, std::function const &); +void require_full_dynamic_graph_satisfies( + DynamicOpenDataflowGraph const &, + std::function const &, + std::function const &, + std::function const &); + std::unordered_multiset get_dynamic_nodes(DynamicOpenDataflowGraph const &); std::unordered_multiset diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.toml index 85f8f299a4..bfe30a7a5e 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.toml @@ -3,6 +3,7 @@ name = "DynamicTensorAccessor" type = "variant" features = [ "eq", + "ord", "fmt", "hash", ] diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml index c9171b928b..56f8ae4359 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_guid_t.dtg.toml @@ -3,6 +3,7 @@ name = "dynamic_tensor_guid_t" type = "variant" features = [ "eq", + "ord", "hash", "fmt", "json", diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml index 378582f428..851b8dc83d 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_tensor_slot.dtg.toml @@ -12,6 +12,7 @@ features = [ includes = [ "op-attrs/tensor_slot_name.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", + "pcg/machine_space_coordinate.dtg.h", "", ] @@ -27,3 +28,13 @@ type = "::FlexFlow::TensorSlotName" [[fields]] name = "slot_tensor_role" type = "std::optional<::FlexFlow::DynamicTensorRole>" + +[[fields]] +name = "task_shard" +type = "std::optional<::FlexFlow::MachineSpaceCoordinate>" +docstring = ''' +\brief For representing parallel operators such as \ref ReplicateAttrs as a single operator with multiple outputs instead of fully shard expanding. + +This is done for the convenience of the runtime, as the \ref ReplicateAttrs is ultimately executed as a single NCCL task/call, so shard-expanding this operator is ultimately counter-productive. +Since the output values, however, do need to be shard-expanded for the rest of the system to work, we make Replicate a single operator with multiple outputs, each with the same \ref TensorSlotName and \ref DynamicTensorRole, but with different \ref MachineSpaceCoordinate ""s to identify which node is ultimately producing that value. +''' diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml index 490a51f88d..1e13736fc7 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.dtg.toml @@ -3,6 +3,7 @@ name = "DynamicValueAttrs" type = "struct" features = [ "eq", + "ord", "fmt", "hash", ] @@ -13,7 +14,7 @@ includes = [ "op-attrs/parallel_tensor_shape.dtg.h", "op-attrs/parallel_tensor_space_coordinate.dtg.h", "pcg/machine_space_coordinate.dtg.h", - "utils/bidict/bidict.h", + "utils/one_to_many/one_to_many.h", "task-spec/dynamic_graph/dynamic_tensor_accessor.dtg.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", ] @@ -25,18 +26,38 @@ src_includes = [ [[fields]] name = "tensor_guid" type = "::FlexFlow::dynamic_tensor_guid_t" +docstring = ''' +\brief The \ref tensor_guid_t or \ref parallel_tensor_guid_t of the (usually parallel) tensor this value originates from. Also allows representing tensors for computing the loss that lie outside of the scope of the \ref ComputationGraph or \ref ParallelComputationGraph, e.g., the label tensor. + +For a \ref DynamicOpenDataflowGraph originating from a \ref MapepdParallelComputationGraph, this field is filled in by \ref make_dynamic_open_dataflow_graph_from_mapped_pcg.h. +''' [[fields]] name = "parallel_tensor_shape" type = "std::optional<::FlexFlow::ParallelTensorShape>" +docstring = ''' +\brief The \ref ParallelTensorShape of the parallel tensor this value originates from. + +For a \ref DynamicOpenDataflowGraph originating form a \ref MappedParallelComputationGraph, this field is filled in by \ref make_dynamic_open_dataflow_graph_from_mapped_pcg.h. +''' [[fields]] name = "shard_coord" type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" +docstring = ''' +\brief The shard (i.e., \ref ParallelTensorSpaceCoordinate) of the (usually parallel) tensor represented by this value. + +For a \ref DynamicOpenDataflowGraph originating from a \ref MappedParallelComputationGraph, this field is filled in by \ref shard_expansion.h. +''' [[fields]] name = "mapping" -type = "std::optional<::FlexFlow::bidict<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>>" +type = "std::optional<::FlexFlow::OneToMany<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>>" +docstring = ''' +\brief The location (i.e., \ref MachineSpaceCoordinate) of each shard (i.e., \ref ParallelTensorSpaceCoordinate) of this (usually parallel) tensor. + +For a \ref DynamicOpenDataflowGraph originating from a \ref MappedParallelComputationGraph, this field is filled in by \ref shard_expansion.h. +''' [[fields]] name = "accessor" @@ -45,3 +66,8 @@ type = "std::optional<::FlexFlow::DynamicTensorAccessor>" [[fields]] name = "role" type = "std::optional<::FlexFlow::DynamicTensorRole>" +docstring = ''' +\brief Identifies the role this tensor plays in training, e.g., a forward tensor, a gradient tensor, an optimizer buffer, a loss tensor, etc. + +For a \ref DynamicOpenDataflowGraph originating from a \ref MappedParallelComputationGraph, this field is filled in by \ref pass_expansion.h. +''' diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h index 9cccc565cc..aa9fbd2874 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs.h @@ -8,6 +8,10 @@ namespace FlexFlow { DynamicValueAttrs decide_dynamic_value_attrs_role(DynamicValueAttrs const &, DynamicTensorRole); +DynamicValueAttrs decide_dynamic_value_attrs_mapping( + DynamicValueAttrs const &, + OneToMany const &); + } // namespace FlexFlow #endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.toml new file mode 100644 index 0000000000..5a9234d815 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.toml @@ -0,0 +1,27 @@ +namespace = "FlexFlow" +name = "DynamicValueAttrsShardingInfo" +type = "struct" +features = [ + "eq", + "ord", + "hash", + "fmt", + "json", +] + +includes = [ + "utils/one_to_many/one_to_many.h", + "op-attrs/parallel_tensor_space_coordinate.dtg.h", + "pcg/machine_space_coordinate.dtg.h", +] + +src_includes = [ +] + +[[fields]] +name = "shard_coord" +type = "::FlexFlow::ParallelTensorSpaceCoordinate" + +[[fields]] +name = "mapping" +type = "::FlexFlow::OneToMany<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/index.dox b/lib/task-spec/include/task-spec/dynamic_graph/index.dox index c48e67f4b3..97b72f8553 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/index.dox +++ b/lib/task-spec/include/task-spec/dynamic_graph/index.dox @@ -7,6 +7,7 @@ namespace FlexFlow { \section task-spec-lowering-passes Lowering Passes +- \ref make_dynamic_open_dataflow_graph_from_mapped_pcg.h: Embeds a \ref MappedParallelComputationGraph as a \ref DynamicOpenDataflowGraph. The first of the lowering passes. - \ref pass_expansion.h - \ref shard_expansion.h - \ref update_insertion.h diff --git a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h index 6a269ec3c9..f1d693c975 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h @@ -3,9 +3,15 @@ #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_layer_invocation_info.dtg.h" namespace FlexFlow { +DynamicNodeInvocation make_dynamic_node_invocation_from_mapped( + MappedParallelLayerInvocationInfo const &); + +DynamicNodeInvocation build_replicate_invocation(MappedParallelLayerInvocationInfo const &); + DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( MappedParallelComputationGraph const &); diff --git a/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h b/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h index 6dce8ad514..ad07b2941f 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/pass_expansion.h @@ -10,8 +10,13 @@ bool node_is_pass_expanded(DynamicNodeAttrs const &); bool value_is_pass_expanded(DynamicValueAttrs const &); bool slot_is_pass_expanded(DynamicTensorSlot const &); +bool node_is_ready_for_pass_expansion(DynamicNodeAttrs const &); +bool value_is_ready_for_pass_expansion(DynamicValueAttrs const &); +bool slot_is_ready_for_pass_expansion(DynamicTensorSlot const &); + bool no_part_of_graph_is_pass_expanded(DynamicOpenDataflowGraph const &); bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &); +bool graph_is_ready_for_pass_expansion(DynamicOpenDataflowGraph const &); DynamicNodeInvocation perform_fwd_pass_expansion_for_invocation(DynamicNodeInvocation const &); diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml index 3c43e1d637..c8ea415ee2 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_attrs.dtg.toml @@ -3,6 +3,7 @@ name = "SerializableDynamicNodeAttrs" type = "struct" features = [ "eq", + "ord", "hash", "fmt", "json", @@ -15,6 +16,7 @@ includes = [ "pcg/mapped_parallel_computation_graph/mapped_operator_task_group.h", "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h", "task-spec/dynamic_graph/training_operation_attrs.dtg.h", + "utils/nonempty_set/nonempty_set.h", ] src_includes = [ @@ -27,8 +29,8 @@ name = "task_type" type = "std::optional<::FlexFlow::DynamicTaskType>" [[fields]] -name = "device_coord" -type = "std::optional<::FlexFlow::MachineSpaceCoordinate>" +name = "device_coords" +type = "std::optional<::FlexFlow::nonempty_set<::FlexFlow::MachineSpaceCoordinate>>" [[fields]] name = "mapping" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml index 01f4cc8876..6051a49876 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_node_invocation.dtg.toml @@ -3,6 +3,7 @@ name = "SerializableDynamicNodeInvocation" type = "struct" features = [ "eq", + "ord", "fmt", "hash", "json", @@ -16,13 +17,13 @@ includes = [ ] src_includes = [ - "utils/hash/unordered_map.h", - "utils/fmt/unordered_map.h", + "utils/hash/map.h", + "utils/fmt/map.h", ] [[fields]] name = "inputs" -type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" [[fields]] name = "node_attrs" @@ -30,4 +31,4 @@ type = "::FlexFlow::SerializableDynamicNodeAttrs" [[fields]] name = "outputs" -type = "std::unordered_map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" +type = "std::map<::FlexFlow::DynamicTensorSlot, ::FlexFlow::SerializableDynamicValueAttrs>" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml index 454f1b7e8c..d05d8e011e 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/serializable_dynamic_value_attrs.dtg.toml @@ -3,6 +3,7 @@ name = "SerializableDynamicValueAttrs" type = "struct" features = [ "eq", + "ord", "hash", "fmt", "json", @@ -14,7 +15,7 @@ includes = [ "op-attrs/parallel_tensor_shape.dtg.h", "op-attrs/parallel_tensor_space_coordinate.dtg.h", "pcg/machine_space_coordinate.dtg.h", - "utils/bidict/bidict.h", + "utils/one_to_many/one_to_many.h", "task-spec/dynamic_graph/dynamic_tensor_role.dtg.h", ] @@ -37,7 +38,7 @@ type = "std::optional<::FlexFlow::ParallelTensorSpaceCoordinate>" [[fields]] name = "mapping" -type = "std::optional<::FlexFlow::bidict<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>>" +type = "std::optional<::FlexFlow::OneToMany<::FlexFlow::ParallelTensorSpaceCoordinate, ::FlexFlow::MachineSpaceCoordinate>>" [[fields]] name = "role" diff --git a/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h b/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h index 4e0db1cd7e..50c713ee3b 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/shard_expansion.h @@ -4,19 +4,42 @@ #include "task-spec/dynamic_graph/dynamic_node_attrs.dtg.h" #include "task-spec/dynamic_graph/dynamic_node_invocation.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.dtg.h" +#include "task-spec/dynamic_graph/dynamic_value_attrs_sharding_info.dtg.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation_sharding_info.dtg.h" namespace FlexFlow { -bool node_is_shard_expanded(DynamicNodeAttrs const &); -bool value_is_shard_expanded(DynamicValueAttrs const &); +[[nodiscard]] bool node_is_shard_expanded(DynamicNodeAttrs const &); +[[nodiscard]] bool value_is_shard_expanded(DynamicValueAttrs const &); +[[nodiscard]] bool invocation_is_fully_shard_expanded(DynamicNodeInvocation const &); -bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &); -bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &); +[[nodiscard]] bool node_is_ready_for_shard_expansion(DynamicNodeAttrs const &); +[[nodiscard]] bool value_is_ready_for_shard_expansion(DynamicValueAttrs const &); +[[nodiscard]] bool invocation_is_ready_for_shard_expansion(DynamicNodeInvocation const &); -std::unordered_set +[[nodiscard]] bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &); +[[nodiscard]] bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &); +[[nodiscard]] bool graph_is_ready_for_shard_expansion(DynamicOpenDataflowGraph const &); + +[[nodiscard]] DynamicNodeAttrs apply_dynamic_node_attrs_sharding_info( + DynamicNodeAttrs const &, + MachineSpaceCoordinate const &); + +[[nodiscard]] DynamicValueAttrs apply_dynamic_value_attrs_sharding_info( + DynamicValueAttrs const &, + DynamicValueAttrsShardingInfo const &); + +[[nodiscard]] DynamicNodeInvocation apply_dynamic_node_invocation_sharding_info( + DynamicNodeInvocation const &, + DynamicNodeInvocationShardingInfo const &); + +[[nodiscard]] std::unordered_set + generate_shard_expansion_for_invocation(DynamicNodeInvocation const &); + +[[nodiscard]] std::unordered_set perform_shard_expansion_for_invocation(DynamicNodeInvocation const &); -DynamicOpenDataflowGraph +[[nodiscard]] DynamicOpenDataflowGraph perform_shard_expansion(DynamicOpenDataflowGraph const &); } // namespace FlexFlow diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml index 8f8f6467c8..2c4e739571 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.dtg.toml @@ -3,6 +3,7 @@ name = "TrainingOperationAttrs" type = "variant" features = [ "eq", + "ord", "hash", "fmt", "json", diff --git a/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h new file mode 100644 index 0000000000..9caea8c341 --- /dev/null +++ b/lib/task-spec/include/task-spec/dynamic_graph/training_operation_attrs.h @@ -0,0 +1,14 @@ +#ifndef _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H +#define _FLEXFLOW_LIB_TASK_SPEC_INCLUDE_TASK_SPEC_DYNAMIC_GRAPH_TRAINING_OPERATION_ATTRS_H + +#include "op-attrs/operator_type.dtg.h" +#include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" + +namespace FlexFlow { + +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &, + OperatorType); + +} // namespace FlexFlow + +#endif diff --git a/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h b/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h index 23fb7050a0..9818152b34 100644 --- a/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h +++ b/lib/task-spec/include/task-spec/dynamic_graph/update_insertion.h @@ -6,6 +6,15 @@ namespace FlexFlow { +bool node_has_already_had_update_insertion_performed(DynamicNodeAttrs const &); +bool value_has_already_had_update_insertion_performed(DynamicValueAttrs const &); + +bool node_is_ready_for_update_insertion(DynamicNodeAttrs const &); +bool value_is_ready_for_update_insertion(DynamicValueAttrs const &); + +bool no_part_of_graph_has_had_update_insertion_performed(DynamicOpenDataflowGraph const &); +bool graph_is_ready_for_update_insertion(DynamicOpenDataflowGraph const &); + std::unordered_set perform_update_insertion_for_invocation(DynamicNodeInvocation const &, OptimizerAttrs const &); diff --git a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc index 4c1b9d4609..ebf84f2d81 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/copy_insertion.cc @@ -10,7 +10,7 @@ #include "task-spec/dynamic_graph/dynamic_tensor_slot.dtg.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "utils/bidict/algorithms/bidict_from_pairs.h" -#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" #include "utils/containers/contains_key.h" #include "utils/containers/flatmap.h" #include "utils/containers/intersection.h" @@ -18,6 +18,7 @@ #include "utils/containers/set_difference.h" #include "utils/containers/transform.h" #include "utils/optional.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" namespace FlexFlow { @@ -31,9 +32,22 @@ bool value_is_mapped(DynamicValueAttrs const &n) { bool no_part_of_graph_is_copy_inserted(DynamicOpenDataflowGraph const &g) { auto slot_is_mapped = [](DynamicTensorSlot const &) -> bool { return false; }; - - return no_part_of_dynamic_graph_satisfies( - g, node_is_copy, value_is_mapped, slot_is_mapped); + for (DynamicNodeInvocation const &i : g.invocations) { + if (node_is_copy(i.node_attrs)) { + return false; + } + for (auto const &[slot, value] : i.inputs) { + if (value_is_mapped(value)) { + return false; + } + } + for (auto const &[slot, value] : i.outputs) { + if (value_is_mapped(value)) { + return false; + } + } + } + return true; } bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &g) { @@ -44,6 +58,26 @@ bool graph_is_fully_copy_inserted(DynamicOpenDataflowGraph const &g) { g, node_is_any, value_is_mapped, slot_is_mapped); } +void require_node_is_ready_for_copy_insertion(DynamicNodeAttrs const &node_attrs) { + ASSERT(node_attrs.mapping.has_value()); +} + +void require_graph_is_ready_for_copy_insertion(DynamicOpenDataflowGraph const &g) { + auto require_slot_is_ready_for_copy_insertion = [](DynamicTensorSlot const &slot) -> void { + return; + }; + + auto require_value_is_ready_for_copy_insertion = [](DynamicValueAttrs const &value_attrs) -> void { + return; + }; + + require_full_dynamic_graph_satisfies( + g, + require_node_is_ready_for_copy_insertion, + require_value_is_ready_for_copy_insertion, + require_slot_is_ready_for_copy_insertion); +} + static DynamicValueAttrs map_dynamic_value_attrs_for_task_group( DynamicTensorSlot const &slot, DynamicValueAttrs const &value, @@ -58,10 +92,10 @@ static std::pair DynamicValueAttrs const &output) { std::unordered_set< std::pair> - input_mapping = unordered_set_of(assert_unwrap(input.mapping)); + input_mapping = unstructured_relation_from_one_to_many(assert_unwrap(input.mapping)); std::unordered_set< std::pair> - output_mapping = unordered_set_of(assert_unwrap(output.mapping)); + output_mapping = unstructured_relation_from_one_to_many(assert_unwrap(output.mapping)); // Exclude the point shared between the input and output mappings, because // those will not result in actual copies once shard expansion is performed @@ -71,19 +105,24 @@ static std::pair DynamicValueAttrs filtered_input = input; filtered_input.mapping = - bidict_from_pairs(set_difference(input_mapping, remove)); + one_to_many_from_unstructured_relation(set_difference(input_mapping, remove)); DynamicValueAttrs filtered_output = output; filtered_output.mapping = - bidict_from_pairs(set_difference(output_mapping, remove)); + one_to_many_from_unstructured_relation(set_difference(output_mapping, remove)); return std::pair{filtered_input, filtered_output}; } -std::unordered_set perform_copy_insertion_for_invocation( - DynamicNodeInvocation const &i, - std::unordered_map const - &unmapped_value_to_mapped_source_value) { +std::unordered_set copies_for_invocation_inputs( + DynamicNodeInvocation const &i, + std::unordered_map const &unmapped_value_to_src_mapped_value) +{ + if (training_op_attrs_has_op_type(assert_unwrap(i.node_attrs.op_attrs), OperatorType::REPLICATE)) { + // copies should not be inserted before a replicate, as the replicate + // implicitly includes the copy operations + return {}; + } MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); @@ -92,33 +131,29 @@ std::unordered_set perform_copy_insertion_for_invocation( return map_dynamic_value_attrs_for_task_group(slot, value, mapping); }; - std::unordered_map mapped_inputs = + std::map mapped_inputs = map_values2(i.inputs, map_tensor); - std::unordered_map mapped_outputs = - map_values2(i.outputs, map_tensor); - std::unordered_set result{DynamicNodeInvocation{ - /*inputs=*/mapped_inputs, - /*node_attrs=*/i.node_attrs, - /*outputs=*/mapped_outputs, - }}; + std::unordered_set result; for (auto const &[slot, input] : i.inputs) { - if (!contains_key(unmapped_value_to_mapped_source_value, input)) { + if (!contains_key(unmapped_value_to_src_mapped_value, input)) { continue; } - DynamicValueAttrs source_value = - unmapped_value_to_mapped_source_value.at(input); - DynamicValueAttrs use_value = mapped_inputs.at(slot); - if (source_value != use_value) { - auto const &[filtered_source, filtered_use] = - filter_mapping_to_avoid_degenerate_copies(source_value, use_value); + DynamicValueAttrs src_mapped_value = unmapped_value_to_src_mapped_value.at(input); + DynamicValueAttrs use_mapped_value = mapped_inputs.at(slot); + + if (src_mapped_value != use_mapped_value) { + auto const &[filtered_source, filtered_use] = filter_mapping_to_avoid_degenerate_copies(src_mapped_value, use_mapped_value); DynamicNodeInvocation copy{ /*inputs=*/{ { - DynamicTensorSlot{TensorSlotName::INPUT, - slot.slot_tensor_role}, + DynamicTensorSlot{ + TensorSlotName::INPUT, + slot.slot_tensor_role, + /*task_shard=*/std::nullopt, + }, filtered_source, }, }, @@ -136,8 +171,11 @@ std::unordered_set perform_copy_insertion_for_invocation( /*outputs=*/ { { - DynamicTensorSlot{TensorSlotName::OUTPUT, - slot.slot_tensor_role}, + DynamicTensorSlot{ + TensorSlotName::OUTPUT, + slot.slot_tensor_role, + /*task_shard=*/std::nullopt, + }, filtered_use, }, }, @@ -149,10 +187,44 @@ std::unordered_set perform_copy_insertion_for_invocation( return result; } +std::unordered_set perform_copy_insertion_for_invocation( + DynamicNodeInvocation const &i, + std::unordered_map const + &unmapped_value_to_mapped_source_value) { + + MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); + + auto map_tensor = [&](DynamicTensorSlot const &slot, + DynamicValueAttrs const &value) { + return map_dynamic_value_attrs_for_task_group(slot, value, mapping); + }; + + DynamicNodeInvocation mapped_i = [&] { + std::map mapped_inputs = + map_values2(i.inputs, map_tensor); + std::map mapped_outputs = + map_values2(i.outputs, map_tensor); + + DynamicNodeInvocation r = i; + r.inputs = mapped_inputs; + r.outputs = mapped_outputs; + return r; + }(); + + std::unordered_set result = set_union( + copies_for_invocation_inputs(i, unmapped_value_to_mapped_source_value), + std::unordered_set{ + mapped_i, + }); + + return result; +} + DynamicOpenDataflowGraph perform_copy_insertion(DynamicOpenDataflowGraph const &g) { ASSERT(no_part_of_graph_is_copy_inserted(g)); + require_graph_is_ready_for_copy_insertion(g); std::unordered_map unmapped_value_to_mapped_source_value; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_node_invocation.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_node_invocation.cc new file mode 100644 index 0000000000..ea50449347 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_node_invocation.cc @@ -0,0 +1,35 @@ +#include "task-spec/dynamic_graph/dynamic_node_invocation.h" +#include "utils/containers/values.h" +#include "utils/containers/keys.h" +#include "utils/containers/all_of.h" + +namespace FlexFlow { + +bool invocation_fully_satisfies(DynamicNodeInvocation const &i, + std::function const &node_condition, + std::function const &value_condition, + std::function const &slot_condition) +{ + return node_condition(i.node_attrs) + && all_of(values(i.inputs), value_condition) + && all_of(keys(i.inputs), slot_condition) + && all_of(values(i.outputs), value_condition) + && all_of(keys(i.outputs), slot_condition); +} + +void require_invocation_fully_satisfies(DynamicNodeInvocation const &i, + std::function const &require_node_condition, + std::function const &require_value_condition, + std::function const &require_slot_condition) { + require_node_condition(i.node_attrs); + for (DynamicTensorSlot const &k : keys(i.inputs)) { + require_slot_condition(k); + require_value_condition(i.inputs.at(k)); + } + for (DynamicTensorSlot const &k : keys(i.outputs)) { + require_slot_condition(k); + require_value_condition(i.outputs.at(k)); + } +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index d2a5b653e5..38d49ef183 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -19,6 +19,9 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_inputs.h" #include "utils/graph/open_kwarg_dataflow_graph/kwarg_dataflow_graph_input.dtg.h" #include "utils/many_to_one/many_to_one.h" +#include "utils/containers/require_all_of.h" +#include "utils/containers/unordered_map_from_map.h" +#include "utils/containers/map_from_unordered.h" namespace FlexFlow { @@ -56,6 +59,18 @@ bool no_part_of_dynamic_graph_satisfies( [&](DynamicTensorSlot const &s) -> bool { return !slot_condition(s); }); } +void require_full_dynamic_graph_satisfies( + DynamicOpenDataflowGraph const &g, + std::function const &node_condition, + std::function const &value_condition, + std::function const &slot_condition) +{ + require_all_of(get_dynamic_nodes(g), node_condition); + require_all_of(get_dynamic_values(g), value_condition); + require_all_of(get_dynamic_tensor_slots(g), slot_condition); +} + + std::unordered_multiset get_dynamic_nodes(DynamicOpenDataflowGraph const &g) { return transform(unordered_multiset_of(g.invocations), @@ -203,16 +218,16 @@ std::pair void { KwargNodeAddedResult added = result.add_node( invocation.node_attrs, - map_values(invocation.inputs, + map_values(unordered_map_from_map(invocation.inputs), [&](DynamicValueAttrs const &input) -> OpenKwargDataflowValue { return value_map.at_r(input); }), - invocation.outputs); + unordered_map_from_map(invocation.outputs)); node_map.equate(added.node, invocation); for (auto const &[k, v] : - zip_values_strict(invocation.outputs, added.outputs)) { + zip_values_strict(invocation.outputs, map_from_unordered(added.outputs))) { DynamicValueAttrs invocation_output = v.first; KwargDataflowOutput graph_output = v.second; value_map.equate( diff --git a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc index 282279edbe..9a70c5cdd0 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/dynamic_value_attrs.cc @@ -13,4 +13,17 @@ DynamicValueAttrs return result; } +DynamicValueAttrs decide_dynamic_value_attrs_mapping( + DynamicValueAttrs const &attrs, + OneToMany const &mapping) +{ + ASSERT(!attrs.mapping.has_value()); + + DynamicValueAttrs result = attrs; + result.mapping = mapping; + + return result; +} + + } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc index 8066926262..8ff24b51ad 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/loss_insertion.cc @@ -18,6 +18,7 @@ LossInsertionResult perform_loss_insertion( LossAttrs const &loss_attrs, dynamic_tensor_guid_t logit_tensor, std::optional const &loss_mapping) { + DynamicValueAttrs logit_value = assert_unwrap( find_output_value_attrs(dg, logit_tensor, mk_dynamic_tensor_role_fwd())); @@ -29,6 +30,7 @@ LossInsertionResult perform_loss_insertion( /*accessor=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_loss(), }; + DynamicValueAttrs logit_grad_value{ /*tensor_guid=*/logit_value.tensor_guid, /*parallel_tensor_shape=*/logit_value.parallel_tensor_shape, @@ -37,14 +39,25 @@ LossInsertionResult perform_loss_insertion( /*accessor=*/std::nullopt, /*role=*/mk_dynamic_tensor_role_bwd(), }; + DynamicNodeInvocation loss_invocation{ /*inputs=*/{ - {DynamicTensorSlot{/*slot_name=*/TensorSlotName::INPUT, - /*slot_tensor_role=*/label_value.role}, - label_value}, - {DynamicTensorSlot{/*slot_name=*/TensorSlotName::LOGIT, - /*slot_tensor_role=*/logit_value.role}, - logit_value}, + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/label_value.role, + /*task_shard=*/std::nullopt, + }, + label_value, + }, + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::LOGIT, + /*slot_tensor_role=*/logit_value.role, + /*task_shard=*/std::nullopt, + }, + logit_value, + }, }, /*node_attrs=*/ DynamicNodeAttrs{ @@ -57,11 +70,17 @@ LossInsertionResult perform_loss_insertion( }, /*outputs=*/ { - {DynamicTensorSlot{/*slot_name=*/TensorSlotName::LOGIT, - /*slot_tensor_role=*/logit_grad_value.role}, - logit_grad_value}, + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::LOGIT, + /*slot_tensor_role=*/logit_grad_value.role, + /*task_shard=*/std::nullopt, + }, + logit_grad_value, + }, }, }; + DynamicOpenDataflowGraph result = dg; result.invocations.insert(loss_invocation); return LossInsertionResult{result, label_value, logit_grad_value}; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc index 0a22015ddf..6dd73bed7d 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/machine_slicing.cc @@ -8,9 +8,9 @@ std::unordered_set DynamicNodeInvocation const &invocation, MachineSpaceCoordinate const &device_coord) { - ASSERT(invocation.node_attrs.device_coord.has_value()); + ASSERT(invocation.node_attrs.device_coords.has_value()); - if (invocation.node_attrs.device_coord.value() == device_coord) { + if (contains(invocation.node_attrs.device_coords.value(), device_coord)) { return {invocation}; } else { return {}; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc index 6bfc477e3a..740a93415e 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_cg.cc @@ -7,10 +7,11 @@ #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include #include #include +#include "utils/containers/map_from_unordered.h" namespace FlexFlow { @@ -39,6 +40,7 @@ DynamicOpenDataflowGraph DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, DynamicValueAttrs{ /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, @@ -50,6 +52,7 @@ DynamicOpenDataflowGraph }, }; }); + std::unordered_map result_outputs = transform( get_outgoing_tensors(cg, layer), @@ -59,6 +62,7 @@ DynamicOpenDataflowGraph DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, DynamicValueAttrs{ /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, @@ -71,7 +75,7 @@ DynamicOpenDataflowGraph }; }); - result.invocations.emplace(result_inputs, result_attrs, result_outputs); + result.invocations.emplace(map_from_unordered(result_inputs), result_attrs, map_from_unordered(result_outputs)); } return result; diff --git a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc index 380c2d17a1..9c7638440f 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -3,80 +3,79 @@ #include "op-attrs/pcg_operator_attrs.h" #include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" #include "pcg/parallel_computation_graph/parallel_computation_graph.h" +#include "pcg/parallel_computation_graph/parallel_computation_graph_edge.h" #include "pcg/parallel_computation_graph/parallel_tensor_attrs.dtg.h" +#include "pcg/parallel_computation_graph/parallel_tensor_use_t.h" #include "task-spec/dynamic_graph/dynamic_layer_guid_t.dtg.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" -#include "utils/containers/generate_map.h" +#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/containers/get_only.h" +#include "utils/containers/map_keys_and_values.h" +#include "utils/containers/require_only_key.h" +#include "utils/containers/transform_pairs.h" #include #include #include +#include "utils/containers/unordered_map_from_map.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" namespace FlexFlow { -DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( - MappedParallelComputationGraph const &mpcg) { - DynamicOpenDataflowGraph result = make_empty_dynamic_open_dataflow_graph(); - - ParallelComputationGraph pcg = pcg_from_mpcg(mpcg); +DynamicNodeInvocation make_dynamic_node_invocation_from_mapped( + MappedParallelLayerInvocationInfo const &invocation_info) +{ + DynamicNodeAttrs result_attrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/invocation_info.layer_info.mapping, + /*op_attrs=*/TrainingOperationAttrs{invocation_info.layer_info.attrs.op_attrs}, + /*pcg_layer_guid=*/dynamic_layer_guid_t{invocation_info.layer_info.guid}, + /*per_device_op_state=*/std::nullopt, + }; - for (auto const &[layer, attrs] : get_parallel_layer_attrs_mapping(pcg)) { - DynamicNodeAttrs result_attrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/mpcg_get_mapping_for_layer(mpcg, layer), - /*op_attrs=*/TrainingOperationAttrs{attrs.op_attrs}, - /*pcg_layer_guid=*/dynamic_layer_guid_t{layer}, - /*per_device_op_state=*/std::nullopt, + auto lift_kv_pair = + [&](TensorSlotName slot_name, + ParallelTensorInfo const &tensor) + -> std::pair + { + return { + DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, + }, + DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{tensor.guid}, + /*parallel_tensor_shape=*/tensor.attrs.shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }, }; + }; + + std::map result_inputs = + transform(invocation_info.incoming, lift_kv_pair); + + std::map result_outputs = + transform(invocation_info.outgoing, lift_kv_pair); - std::unordered_map result_inputs = - transform(get_incoming_tensors(pcg, layer), - [&](TensorSlotName const &slot_name, - parallel_tensor_guid_t const &tensor) { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); - return std::pair{ - DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, - }, - DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }, - }; - }); - std::unordered_map result_outputs = - transform(get_outgoing_tensors(pcg, layer), - [&](TensorSlotName const &slot_name, - parallel_tensor_guid_t const &tensor) { - ParallelTensorAttrs attrs = - get_parallel_tensor_attrs(pcg, tensor); - return std::pair{ - DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/std::nullopt, - }, - DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{tensor}, - /*parallel_tensor_shape=*/attrs.shape, - /*shard_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, - }, - }; - }); + DynamicNodeInvocation invocation = DynamicNodeInvocation{ + /*inputs=*/result_inputs, + /*node_attrs=*/result_attrs, + /*outputs=*/result_outputs, + }; - result.invocations.emplace(result_inputs, result_attrs, result_outputs); - } + return invocation; +} + +DynamicOpenDataflowGraph make_dynamic_open_dataflow_graph_from_mapped_pcg( + MappedParallelComputationGraph const &mpcg) { - return result; + return dynamic_open_dataflow_graph_from_invocation_set( + transform(unordered_set_of(mpcg_get_invocation_set(mpcg)), make_dynamic_node_invocation_from_mapped)); } } // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc index 0cee06368f..a348ba77da 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/pass_expansion.cc @@ -1,7 +1,9 @@ #include "task-spec/dynamic_graph/pass_expansion.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" #include "utils/containers/are_all_same.h" +#include "utils/containers/get_only.h" #include "utils/containers/merge_disjoint_maps.h" #include "utils/containers/transform.h" @@ -19,6 +21,18 @@ bool value_is_pass_expanded(DynamicValueAttrs const &v) { return v.role.has_value(); } +bool node_is_ready_for_pass_expansion(DynamicNodeAttrs const &) { + return true; +} + +bool value_is_ready_for_pass_expansion(DynamicValueAttrs const &) { + return true; +} + +bool slot_is_ready_for_pass_expansion(DynamicTensorSlot const &) { + return true; +} + bool no_part_of_graph_is_pass_expanded(DynamicOpenDataflowGraph const &g) { return no_part_of_dynamic_graph_satisfies( g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); @@ -29,6 +43,11 @@ bool graph_is_fully_pass_expanded(DynamicOpenDataflowGraph const &g) { g, node_is_pass_expanded, value_is_pass_expanded, slot_is_pass_expanded); } +bool graph_is_ready_for_pass_expansion(DynamicOpenDataflowGraph const &g) { + return full_dynamic_graph_satisfies( + g, node_is_ready_for_pass_expansion, value_is_ready_for_pass_expansion, slot_is_ready_for_pass_expansion); +} + DynamicTensorSlot pass_expand_slot(DynamicTensorSlot const &s, FwbTensorType tensor_type) { ASSERT(!slot_is_pass_expanded(s)); @@ -82,6 +101,9 @@ DynamicNodeInvocation perform_fwd_pass_expansion_for_invocation( DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( DynamicNodeInvocation const &invocation) { + TrainingOperationAttrs op_attrs = + assert_unwrap(invocation.node_attrs.op_attrs); + auto to_fwd = [](DynamicTensorSlot const &k, DynamicValueAttrs const &v) { return std::pair{ pass_expand_slot(k, FwbTensorType::FORWARD), @@ -96,17 +118,37 @@ DynamicNodeInvocation perform_bwd_pass_expansion_for_invocation( }; }; - return DynamicNodeInvocation{ - /*inputs=*/ - merge_disjoint_maps(std::vector{ - transform(invocation.inputs, to_fwd), - transform(invocation.outputs, to_fwd), - transform(invocation.outputs, to_grad), - }), - /*node_attrs=*/ - pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), - /*outputs=*/ - transform(invocation.inputs, to_grad), + if (training_op_attrs_has_op_type(op_attrs, OperatorType::REPLICATE)) { + auto [input_slot, input] = get_only(invocation.inputs); + auto [output_slot, output] = get_only(invocation.outputs); + + DynamicNodeInvocation bwd{ + /*inputs=*/{ + to_fwd(output_slot, output), + to_grad(output_slot, output), + }, + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), + /*outputs=*/ + { + to_grad(input_slot, input), + }, + }; + + return bwd; + } else { + return DynamicNodeInvocation{ + /*inputs=*/ + merge_disjoint_maps(std::vector{ + transform(invocation.inputs, to_fwd), + transform(invocation.outputs, to_fwd), + transform(invocation.outputs, to_grad), + }), + /*node_attrs=*/ + pass_expand_node(invocation.node_attrs, DynamicTaskType::BWD), + /*outputs=*/ + transform(invocation.inputs, to_grad), + }; }; } @@ -114,6 +156,7 @@ DynamicOpenDataflowGraph perform_pass_expansion(DynamicOpenDataflowGraph const &g) { ASSERT(no_part_of_graph_is_pass_expanded(g)); + ASSERT(graph_is_ready_for_pass_expansion(g)); DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set( g, [](DynamicNodeInvocation const &invocation) { diff --git a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc index d613194d14..67ae7b58f3 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/serializable_dynamic_node_attrs.cc @@ -7,7 +7,7 @@ SerializableDynamicNodeAttrs dynamic_node_attrs_to_serializable(DynamicNodeAttrs const &attrs) { return SerializableDynamicNodeAttrs{ /*task_type=*/attrs.task_type, - /*device_coord=*/attrs.device_coord, + /*device_coords=*/attrs.device_coords, /*mapping=*/attrs.mapping, /*op_attrs=*/attrs.op_attrs, /*layer_guid=*/attrs.layer_guid, @@ -18,7 +18,7 @@ DynamicNodeAttrs dynamic_node_attrs_from_serializable( SerializableDynamicNodeAttrs const &attrs) { return DynamicNodeAttrs{ /*task_type=*/attrs.task_type, - /*device_coord=*/attrs.device_coord, + /*device_coords=*/attrs.device_coords, /*mapping=*/attrs.mapping, /*op_attrs=*/attrs.op_attrs, /*layer_guid=*/attrs.layer_guid, diff --git a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc index fb6efb96d0..f910c6c6c2 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/shard_expansion.cc @@ -1,21 +1,99 @@ #include "task-spec/dynamic_graph/shard_expansion.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" -#include "utils/bidict/algorithms/filter_keys.h" +#include "utils/bidict/algorithms/bidict_filter_keys.h" #include "utils/containers/get_only.h" #include "utils/containers/map_values2.h" #include "utils/containers/require_same.h" #include "utils/containers/transform.h" #include "utils/optional.h" +#include "utils/containers/binary_merge_disjoint_maps.h" +#include "task-spec/dynamic_graph/dynamic_node_invocation.h" +#include "utils/containers/map_from_unordered.h" +#include "utils/one_to_many/one_to_many_filter_keys.h" +#include "task-spec/dynamic_graph/training_operation_attrs.h" +#include "utils/one_to_many/require_one_to_many_is_bijection.h" +#include "pcg/mapped_parallel_computation_graph/operator_atomic_task_shard_binding.h" +#include "utils/bidict/algorithms/bidict_filter_values.h" +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "utils/containers/merge_disjoint_maps.h" +#include "utils/containers/map_keys.h" +#include "utils/containers/require_only_key.h" namespace FlexFlow { bool node_is_shard_expanded(DynamicNodeAttrs const &n) { - return n.device_coord.has_value(); + return n.device_coords.has_value(); +} + +bool node_is_ready_for_shard_expansion(DynamicNodeAttrs const &n) { + if (!n.op_attrs.has_value()) { + return false; + } + + if (n.op_attrs.value().is_pcg_op()) { + if (!n.mapping.has_value()) { + return false; + } + } + + return true; +} + +void require_node_is_ready_for_shard_expansion(DynamicNodeAttrs const &n) { + ASSERT(n.op_attrs.has_value()); + if (n.op_attrs.value().is_pcg_op()) { + ASSERT(n.mapping.has_value()); + } +} + + +bool invocation_is_fully_shard_expanded(DynamicNodeInvocation const &i) { + auto slot_is_shard_expanded = [](DynamicTensorSlot const &) { + return true; + }; + + return invocation_fully_satisfies( + i, + node_is_shard_expanded, + value_is_shard_expanded, + slot_is_shard_expanded); } bool value_is_shard_expanded(DynamicValueAttrs const &n) { - return n.shard_coord.has_value(); + return n.shard_coord.has_value() && n.mapping.has_value(); +} + +bool value_is_ready_for_shard_expansion(DynamicValueAttrs const &n) { + return true; +} + +void require_value_is_ready_for_shard_expansion(DynamicValueAttrs const &n) { + return; +} + +bool invocation_is_ready_for_shard_expansion(DynamicNodeInvocation const &i) { + auto slot_is_ready_for_shard_expansion = [](DynamicTensorSlot const &) { + return true; + }; + + return invocation_fully_satisfies( + i, + node_is_ready_for_shard_expansion, + value_is_ready_for_shard_expansion, + slot_is_ready_for_shard_expansion); +} + +void require_invocation_is_ready_for_shard_expansion(DynamicNodeInvocation const &i) { + auto require_slot_is_ready_for_shard_expansion = [](DynamicTensorSlot const &) -> void { + return; + }; + + require_invocation_fully_satisfies( + i, + require_node_is_ready_for_shard_expansion, + require_value_is_ready_for_shard_expansion, + require_slot_is_ready_for_shard_expansion); } bool no_part_of_graph_is_shard_expanded(DynamicOpenDataflowGraph const &g) { @@ -40,20 +118,51 @@ bool graph_is_fully_shard_expanded(DynamicOpenDataflowGraph const &g) { slot_is_shard_expanded); } -static bidict +static OneToMany restrict_tensor_mapping_keys_to_coord( - bidict const + OneToMany const &mapping, ParallelTensorSpaceCoordinate const ¶llel_tensor_coord) { - return filter_keys(mapping, [&](ParallelTensorSpaceCoordinate const &p) { + return one_to_many_filter_keys(mapping, [&](ParallelTensorSpaceCoordinate const &p) { return p == parallel_tensor_coord; }); } +static DynamicNodeInvocationShardingInfo invocation_sharding_info_for_binding( + DynamicNodeInvocation const &i, + MachineSpaceCoordinate const &machine_coord, + OperatorAtomicTaskShardBinding const &binding) { + + auto shard_expand_value_attrs = + [&](DynamicTensorSlot const &s, DynamicValueAttrs const &v) -> DynamicValueAttrsShardingInfo { + ParallelTensorSpaceCoordinate parallel_tensor_coord = + binding.tensor_coords.at(s.slot_name); + + return DynamicValueAttrsShardingInfo{ + /*shard_coord=*/parallel_tensor_coord, + /*mapping=*/restrict_tensor_mapping_keys_to_coord(v.mapping.value(), parallel_tensor_coord), + }; + }; + + DynamicNodeAttrs expanded_node_attrs = [&]() { + DynamicNodeAttrs result = i.node_attrs; + result.device_coords = nonempty_set{machine_coord}; + return result; + }(); + + return DynamicNodeInvocationShardingInfo{ + /*device_coord=*/nonempty_set{machine_coord}, + /*value_sharding=*/map_values2( + binary_merge_disjoint_maps(i.inputs, i.outputs), + shard_expand_value_attrs), + }; +} + static DynamicNodeInvocation shard_invocation_for_binding( DynamicNodeInvocation const &i, MachineSpaceCoordinate const &machine_coord, OperatorAtomicTaskShardBinding const &binding) { + auto shard_expand_value_attrs = [&](DynamicTensorSlot const &s, DynamicValueAttrs const &v) -> DynamicValueAttrs { @@ -64,8 +173,9 @@ static DynamicNodeInvocation shard_invocation_for_binding( result.shard_coord = parallel_tensor_coord; result.mapping = transform( v.mapping, - [&](bidict const - &mapping) { + [&](OneToMany const &mapping) + -> OneToMany + { return restrict_tensor_mapping_keys_to_coord(mapping, parallel_tensor_coord); }); @@ -74,7 +184,7 @@ static DynamicNodeInvocation shard_invocation_for_binding( DynamicNodeAttrs expanded_node_attrs = [&]() { DynamicNodeAttrs result = i.node_attrs; - result.device_coord = machine_coord; + result.device_coords = nonempty_set{machine_coord}; return result; }(); @@ -85,17 +195,19 @@ static DynamicNodeInvocation shard_invocation_for_binding( }; } -static std::unordered_set - perform_shard_expansion_for_copy(DynamicNodeInvocation const &i) { +static std::set + generate_shard_expansion_for_copy(DynamicNodeInvocation const &i) { auto [input_slot, input] = get_only(i.inputs); auto [output_slot, output] = get_only(i.outputs); - bidict input_mapping = + + OneToMany input_mapping = assert_unwrap(input.mapping); require_same(input_mapping.left_values(), assert_unwrap(output.mapping).left_values()); return transform( - input_mapping.left_values(), [&](ParallelTensorSpaceCoordinate const &p) { + input_mapping.left_values(), + [&](ParallelTensorSpaceCoordinate const &p) -> DynamicNodeInvocationShardingInfo { // The machine coord for a copy is inherently nebulous because it // doesn't strictly run in any single location. Further, Realm has the // flexibility to issue a copy operation from anywhere in the machine, @@ -103,9 +215,9 @@ static std::unordered_set // because we expect this to align with the most efficient way to issue // copies in Realm, although the current Realm backend uses a // centralized controller and thus issues copies all from a single node. - MachineSpaceCoordinate machine_coord = input_mapping.at_l(p); + MachineSpaceCoordinate machine_coord = get_only(input_mapping.at_l(p)); - return shard_invocation_for_binding(i, + return invocation_sharding_info_for_binding(i, machine_coord, OperatorAtomicTaskShardBinding{{ {input_slot.slot_name, p}, @@ -114,11 +226,341 @@ static std::unordered_set }); } +// TODO(@lockshaw): There is a lot of code duplication between +// generate_shard_expansion_for_fwd_replicate and +// generate_shard_expansion_for_bwd_replicate that should eventually be +// factored out. +static std::set + generate_shard_expansion_for_fwd_replicate(DynamicNodeInvocation const &i) { + ASSERT(i.node_attrs.task_type == DynamicTaskType::FWD); + + MappedOperatorTaskGroup node_mapping = assert_unwrap(i.node_attrs.mapping); + + DynamicTensorSlot expected_input_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrs input = require_only_key(i.inputs, expected_input_slot); + + DynamicTensorSlot expected_output_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrs output = require_only_key(i.outputs, expected_output_slot); + + bidict + input_value_mapping = require_one_to_many_is_bijection( + assert_unwrap(input.mapping)); + + std::set input_tensor_shards = set_of(input_value_mapping.left_values()); + + bidict + output_value_mapping = require_one_to_many_is_bijection( + assert_unwrap(output.mapping)); + + auto get_task_shard_machine_coords_for_input_tensor_shard + = [&](ParallelTensorSpaceCoordinate const &input_tensor_shard) + -> nonempty_set + { + bidict dependent_on_input_tensor_shard + = bidict_filter_values( + node_mapping.get_shard_bindings(), + [&](OperatorAtomicTaskShardBinding const &b) -> bool { + return ptensor_space_coord_for_slot_name(b, TensorSlotName::INPUT) == input_tensor_shard; + }); + + return nonempty_set(set_of(dependent_on_input_tensor_shard.left_values())); + }; + + auto invocation_sharding_info_for_input_tensor_shard = [&](ParallelTensorSpaceCoordinate const &c) + -> DynamicNodeInvocationShardingInfo + { + nonempty_set task_shard_machine_coords = + get_task_shard_machine_coords_for_input_tensor_shard(c); + + std::map output_sharding_infos = + generate_map(task_shard_machine_coords.unwrap_as_set(), + [&](MachineSpaceCoordinate const &mc) + -> DynamicValueAttrsShardingInfo + { + ParallelTensorSpaceCoordinate pc = output_value_mapping.at_r(mc); + + return DynamicValueAttrsShardingInfo{ + /*shard_coord=*/pc, + /*mapping=*/OneToMany{ + { + pc, + {mc}, + }, + }, + }; + }); + + std::map keyed_output_sharding_infos = + map_keys(output_sharding_infos, + [&](MachineSpaceCoordinate const &mc) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/mc, + }; + }); + + DynamicTensorSlot input_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrsShardingInfo input_sharding_info = DynamicValueAttrsShardingInfo{ + /*shard_coord=*/c, + /*mapping=*/OneToMany{ + { + c, + {input_value_mapping.at_l(c)}, + }, + }, + }; + + std::map sharding_infos = + binary_merge_disjoint_maps( + keyed_output_sharding_infos, + std::map{ + { + input_slot, + input_sharding_info, + }, + }); + + return DynamicNodeInvocationShardingInfo{ + /*device_coords=*/task_shard_machine_coords, + /*value_sharding=*/sharding_infos, + }; + }; + + return transform(input_tensor_shards, invocation_sharding_info_for_input_tensor_shard); +} + +static std::set + generate_shard_expansion_for_bwd_replicate(DynamicNodeInvocation const &i) { + ASSERT(i.node_attrs.task_type == DynamicTaskType::BWD); + + MappedOperatorTaskGroup node_mapping = assert_unwrap(i.node_attrs.mapping); + + DynamicTensorSlot expected_output_grad_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrs output_grad = require_only_key(i.inputs, expected_output_grad_slot); + + DynamicTensorSlot expected_input_grad_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrs input_grad = require_only_key(i.outputs, expected_input_grad_slot); + + bidict + output_grad_value_mapping = require_one_to_many_is_bijection( + assert_unwrap(output_grad.mapping)); + + bidict + input_grad_value_mapping = require_one_to_many_is_bijection( + assert_unwrap(input_grad.mapping)); + + std::set input_grad_tensor_shards = set_of(input_grad_value_mapping.left_values()); + + auto get_task_shard_machine_coords_for_input_grad_tensor_shard + = [&](ParallelTensorSpaceCoordinate const &input_grad_tensor_shard) + -> nonempty_set + { + bidict produce_input_grad_tensor_shard + = bidict_filter_values( + node_mapping.get_shard_bindings(), + [&](OperatorAtomicTaskShardBinding const &b) -> bool { + return ptensor_space_coord_for_slot_name(b, TensorSlotName::INPUT) == input_grad_tensor_shard; + }); + + return nonempty_set(set_of(produce_input_grad_tensor_shard.left_values())); + }; + + auto invocation_sharding_info_for_input_grad_tensor_shard = [&](ParallelTensorSpaceCoordinate const &c) + -> DynamicNodeInvocationShardingInfo + { + nonempty_set task_shard_machine_coords = + get_task_shard_machine_coords_for_input_grad_tensor_shard(c); + + std::map output_grad_sharding_infos = + generate_map(task_shard_machine_coords.unwrap_as_set(), + [&](MachineSpaceCoordinate const &mc) + -> DynamicValueAttrsShardingInfo + { + ParallelTensorSpaceCoordinate pc = output_grad_value_mapping.at_r(mc); + + return DynamicValueAttrsShardingInfo{ + /*shard_coord=*/pc, + /*mapping=*/OneToMany{ + { + pc, + {mc}, + }, + }, + }; + }); + + std::map keyed_output_grad_sharding_infos = + map_keys(output_grad_sharding_infos, + [&](MachineSpaceCoordinate const &mc) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/mc, + }; + }); + + DynamicTensorSlot input_grad_slot = DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }; + + DynamicValueAttrsShardingInfo input_grad_sharding_info = DynamicValueAttrsShardingInfo{ + /*shard_coord=*/c, + /*mapping=*/OneToMany{ + { + c, + {input_grad_value_mapping.at_l(c)}, + }, + }, + }; + + std::map sharding_infos = + binary_merge_disjoint_maps( + keyed_output_grad_sharding_infos, + std::map{ + { + input_grad_slot, + input_grad_sharding_info, + }, + }); + + return DynamicNodeInvocationShardingInfo{ + /*device_coords=*/task_shard_machine_coords, + /*value_sharding=*/sharding_infos, + }; + }; + + return transform(input_grad_tensor_shards, invocation_sharding_info_for_input_grad_tensor_shard); +} + std::unordered_set perform_shard_expansion_for_invocation(DynamicNodeInvocation const &i) { - if (i.node_attrs.op_attrs.has_value() && - i.node_attrs.op_attrs.value().is_copy()) { - return perform_shard_expansion_for_copy(i); + + std::unordered_set + shard_expansion_info = generate_shard_expansion_for_invocation(i); + + return transform( + shard_expansion_info, + [&](DynamicNodeInvocationShardingInfo const &s) + -> DynamicNodeInvocation + { + return apply_dynamic_node_invocation_sharding_info(i, s); + }); +} + +bool graph_is_ready_for_shard_expansion(DynamicOpenDataflowGraph const &g) { + auto slot_is_ready_for_shard_expansion = [](DynamicTensorSlot const &) -> bool { + return false; + }; + + return full_dynamic_graph_satisfies(g, + node_is_ready_for_shard_expansion, + value_is_ready_for_shard_expansion, + slot_is_ready_for_shard_expansion); +} + + +void require_graph_is_ready_for_shard_expansion(DynamicOpenDataflowGraph const &g) { + auto require_slot_is_ready_for_shard_expansion = [](DynamicTensorSlot const &) -> void { + return; + }; + + return require_full_dynamic_graph_satisfies(g, + require_node_is_ready_for_shard_expansion, + require_value_is_ready_for_shard_expansion, + require_slot_is_ready_for_shard_expansion); +} + +DynamicNodeAttrs apply_dynamic_node_attrs_sharding_info( + DynamicNodeAttrs const &node_attrs, + nonempty_set const &device_coords) +{ + DynamicNodeAttrs result = node_attrs; + result.device_coords = device_coords; + + return result; +} + +DynamicValueAttrs apply_dynamic_value_attrs_sharding_info( + DynamicValueAttrs const &value_attrs, + DynamicValueAttrsShardingInfo const &value_sharding_info) +{ + DynamicValueAttrs result = value_attrs; + result.shard_coord = value_sharding_info.shard_coord; + result.mapping = value_sharding_info.mapping; + return result; +} + +DynamicNodeInvocation apply_dynamic_node_invocation_sharding_info( + DynamicNodeInvocation const &invocation, + DynamicNodeInvocationShardingInfo const &invocation_sharding_info) +{ + require_invocation_is_ready_for_shard_expansion(invocation); + + auto shard_value = [&](DynamicTensorSlot const &slot, DynamicValueAttrs const &value_attrs) + -> DynamicValueAttrs + { + DynamicValueAttrsShardingInfo sharding_info = invocation_sharding_info.value_sharding.at(slot); + return apply_dynamic_value_attrs_sharding_info(value_attrs, sharding_info); + }; + + DynamicNodeInvocation result = DynamicNodeInvocation{ + /*inputs=*/map_values2(invocation.inputs, shard_value), + /*node_attrs=*/apply_dynamic_node_attrs_sharding_info( + invocation.node_attrs, invocation_sharding_info.device_coords), + /*outputs=*/map_values2(invocation.outputs, shard_value), + }; + + ASSERT(invocation_is_fully_shard_expanded(result)); + return result; +} + +std::unordered_set + generate_shard_expansion_for_invocation(DynamicNodeInvocation const &i) +{ + require_invocation_is_ready_for_shard_expansion(i); + + if (i.node_attrs.op_attrs.value().is_copy()) { + return unordered_set_of(generate_shard_expansion_for_copy(i)); + } + + if (training_op_attrs_has_op_type(i.node_attrs.op_attrs.value(), OperatorType::REPLICATE)) { + DynamicTaskType task_type = assert_unwrap(i.node_attrs.task_type); + switch (task_type) { + case DynamicTaskType::FWD: + return unordered_set_of(generate_shard_expansion_for_fwd_replicate(i)); + case DynamicTaskType::BWD: + return unordered_set_of(generate_shard_expansion_for_bwd_replicate(i)); + default: + PANIC("Unexpected task type for Replicate: {}", task_type); + } } MappedOperatorTaskGroup mapping = assert_unwrap(i.node_attrs.mapping); @@ -128,11 +570,11 @@ std::unordered_set return transform( shard_machine_coords, - [&](MachineSpaceCoordinate const &c) -> DynamicNodeInvocation { + [&](MachineSpaceCoordinate const &c) -> DynamicNodeInvocationShardingInfo { OperatorAtomicTaskShardBinding slot_bindings = mapping.get_shard_bindings().at_l(c); - return shard_invocation_for_binding(i, c, slot_bindings); + return invocation_sharding_info_for_binding(i, c, slot_bindings); }); } @@ -140,6 +582,7 @@ DynamicOpenDataflowGraph perform_shard_expansion(DynamicOpenDataflowGraph const &g) { ASSERT(no_part_of_graph_is_shard_expanded(g)); + require_graph_is_ready_for_shard_expansion(g); DynamicOpenDataflowGraph result = flatmap_dynamic_invocation_set(g, [&](DynamicNodeInvocation const &i) { diff --git a/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc new file mode 100644 index 0000000000..a9be225ff5 --- /dev/null +++ b/lib/task-spec/src/task-spec/dynamic_graph/training_operation_attrs.cc @@ -0,0 +1,18 @@ +#include "task-spec/dynamic_graph/training_operation_attrs.h" +#include "op-attrs/pcg_operator_attrs.h" +#include "utils/overload.h" + +namespace FlexFlow { + +bool training_op_attrs_has_op_type(TrainingOperationAttrs const &op_attrs, + OperatorType op_type) { + return op_attrs.visit(overload{ + [&](PCGOperatorAttrs const &pcg_op_attrs) -> bool { + return pcg_op_attrs_get_op_type(pcg_op_attrs) == op_type; + }, + [](LossAttrs const &) -> bool { return false; }, + [](CopyAttrs const &) -> bool { return false; }, + }); +} + +} // namespace FlexFlow diff --git a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc index 58a32db6c1..51a79cff59 100644 --- a/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc +++ b/lib/task-spec/src/task-spec/dynamic_graph/update_insertion.cc @@ -79,7 +79,7 @@ static DynamicNodeInvocation get_update_invocation_for_invocation( /*inputs=*/map_from_pairs( transform(tensor_roles, create_binding_for_role)), /*node_attrs=*/update_node_attrs, - /*outputs=*/std::unordered_map{}, + /*outputs=*/std::map{}, }; } diff --git a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc index 13465d7a5f..c8460af538 100644 --- a/lib/task-spec/src/task-spec/ops/impl/element_binary.cc +++ b/lib/task-spec/src/task-spec/ops/impl/element_binary.cc @@ -36,8 +36,8 @@ static std::optional forward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementBinaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_binary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_binary(); ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); device_handle_t handle = acc.get_ff_handle(); @@ -62,8 +62,8 @@ static std::optional backward_task_impl(TaskArgumentAccessor const &acc) { ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementBinaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_binary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_binary(); ElementBinaryAttrs attrs = acc.get_op_attrs().require_element_binary(); device_handle_t handle = acc.get_ff_handle(); diff --git a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc index d66ff9ab8d..9a092b90b8 100644 --- a/lib/task-spec/src/task-spec/ops/impl/element_unary.cc +++ b/lib/task-spec/src/task-spec/ops/impl/element_unary.cc @@ -35,8 +35,8 @@ static std::optional ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementUnaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_unary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_unary(); return profile(forward_kernel, profiling, @@ -62,8 +62,8 @@ static std::optional ProfilingSettings profiling = acc.get_profiling_settings(); DeviceType kernel_device_type = acc.get_kernel_device_type(); - ElementUnaryPerDeviceState per_device_state = - acc.get_per_device_op_state().require_element_unary().value(); + std::optional per_device_state = + acc.get_per_device_op_state().require_element_unary(); return profile(backward_kernel, profiling, diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc index 2160f6bf82..05324c2195 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/copy_insertion.cc @@ -6,11 +6,14 @@ #include "task-spec/dynamic_graph/dynamic_value_attrs.dtg.h" #include "test/utils/doctest/fmt/unordered_set.h" #include +#include "task-spec/dynamic_graph/dynamic_value_attrs.h" +#include "task-spec/dynamic_graph/serializable_dynamic_node_invocation.h" +#include "op-attrs/ops/element_unary.h" using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("perform_copy_insertion_for_invocation") { + TEST_CASE("copies_for_invocation_inputs") { auto mk_machine_coord = [](nonnegative_int node_idx, nonnegative_int device_idx) -> MachineSpaceCoordinate { @@ -21,6 +24,34 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; + auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { + return DynamicTensorSlot{ + /*slot_name=*/slot_name, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }; + }; + + auto mk_value = [](size_t src_node_id, + TensorSlotName src_slot_name) + -> DynamicValueAttrs { + return DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{ + parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{src_node_id}, + src_slot_name, + }, + }, + }, + /*parallel_tensor_shape=*/std::nullopt, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }; + }; + auto mk_pt_coord = [](nonnegative_int idx1, nonnegative_int idx2, @@ -37,397 +68,571 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; - auto mk_input_shard_binding = [&](ParallelTensorSpaceCoordinate const &c) - -> OperatorAtomicTaskShardBinding { - return OperatorAtomicTaskShardBinding{ - /*tensor_coords=*/{ + size_t invocation_id = 20; + + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); + MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); + MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + + SUBCASE("standard operator") { + auto mk_input_shard_binding = [&](ParallelTensorSpaceCoordinate const &c) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::OUTPUT, + c, + }, + }, + }; + }; + + auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + ParallelTensorSpaceCoordinate const &c2, + ParallelTensorSpaceCoordinate const &c3, + ParallelTensorSpaceCoordinate const &c4) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::INPUT, + c1, + }, + { + TensorSlotName::WEIGHT, + c2, + }, + { + TensorSlotName::OUTPUT_1, + c3, + }, + { + TensorSlotName::OUTPUT_2, + c4, + }, + }, + }; + }; + + ParallelTensorSpaceCoordinate mc1_input_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc1_weight_coord = + mk_pt_coord(0_n, 1_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc1_output_1_coord = + mk_pt_coord(1_n, 0_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc1_output_2_coord = + mk_pt_coord(3_n, 0_n, 0_n, 0_n); + + ParallelTensorSpaceCoordinate mc2_input_coord = + mk_pt_coord(0_n, 1_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc2_weight_coord = + mk_pt_coord(0_n, 4_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc2_output_1_coord = + mk_pt_coord(1_n, 2_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc2_output_2_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + MappedOperatorTaskGroup input_mapping_same = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_input_shard_binding(mc1_input_coord), + }, { - TensorSlotName::OUTPUT, - c, + mc2, + mk_input_shard_binding(mc2_input_coord), }, }, }; - }; - auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, - ParallelTensorSpaceCoordinate const &c2, - ParallelTensorSpaceCoordinate const &c3, - ParallelTensorSpaceCoordinate const &c4) - -> OperatorAtomicTaskShardBinding { - return OperatorAtomicTaskShardBinding{ - /*tensor_coords=*/{ + MappedOperatorTaskGroup weight_mapping_same = MappedOperatorTaskGroup{ + bidict{ { - TensorSlotName::INPUT, - c1, + mc1, + mk_input_shard_binding(mc1_weight_coord), }, { - TensorSlotName::WEIGHT, - c2, + mc2, + mk_input_shard_binding(mc2_weight_coord), }, + }, + }; + + MappedOperatorTaskGroup invocation_mapping = MappedOperatorTaskGroup{ + bidict{ { - TensorSlotName::OUTPUT_1, - c3, + mc1, + mk_shard_binding(mc1_input_coord, + mc1_weight_coord, + mc1_output_1_coord, + mc1_output_2_coord), }, { - TensorSlotName::OUTPUT_2, - c4, + mc2, + mk_shard_binding(mc2_input_coord, + mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), }, }, }; - }; - MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); - MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); - MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); - MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + MappedOperatorTaskGroup invocation_mapping_diff_vs_copy1 = + MappedOperatorTaskGroup{ + bidict{ + { + mc2, + mk_shard_binding(mc2_input_coord, + mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), + }, + }, + }; - ParallelTensorSpaceCoordinate mc1_input_coord = - mk_pt_coord(0_n, 0_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate mc1_weight_coord = - mk_pt_coord(0_n, 1_n, 2_n, 0_n); - ParallelTensorSpaceCoordinate mc1_output_1_coord = - mk_pt_coord(1_n, 0_n, 0_n, 1_n); - ParallelTensorSpaceCoordinate mc1_output_2_coord = - mk_pt_coord(3_n, 0_n, 0_n, 0_n); - - ParallelTensorSpaceCoordinate mc2_input_coord = - mk_pt_coord(0_n, 1_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate mc2_weight_coord = - mk_pt_coord(0_n, 4_n, 2_n, 0_n); - ParallelTensorSpaceCoordinate mc2_output_1_coord = - mk_pt_coord(1_n, 2_n, 0_n, 1_n); - ParallelTensorSpaceCoordinate mc2_output_2_coord = - mk_pt_coord(0_n, 0_n, 0_n, 0_n); - - MappedOperatorTaskGroup input_mapping_same = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_input_shard_binding(mc1_input_coord), - }, - { - mc2, - mk_input_shard_binding(mc2_input_coord), - }, - }, - }; + DynamicValueAttrs graph_input1 = + mk_value(0, TensorSlotName::OUTPUT); + + DynamicValueAttrs graph_input1_use = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::INPUT)); + + DynamicValueAttrs graph_input1_use_diff_vs_copy1 = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(invocation_mapping_diff_vs_copy1, TensorSlotName::INPUT)); + + DynamicValueAttrs graph_input2 = + mk_value(1, TensorSlotName::OUTPUT); + + DynamicValueAttrs graph_input2_use = + decide_dynamic_value_attrs_mapping( + graph_input2, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::WEIGHT)); + + DynamicValueAttrs invocation_output1 = mk_value(invocation_id, + TensorSlotName::OUTPUT_1); + DynamicValueAttrs invocation_output1_src = + decide_dynamic_value_attrs_mapping( + invocation_output1, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::OUTPUT_1)); + + DynamicValueAttrs invocation_output2 = mk_value(invocation_id, + TensorSlotName::OUTPUT_2); + DynamicValueAttrs invocation_output2_src = + decide_dynamic_value_attrs_mapping( + invocation_output2, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::OUTPUT_2)); + + DynamicValueAttrs graph_input1_src_same = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(input_mapping_same, TensorSlotName::OUTPUT)); + + DynamicValueAttrs graph_input2_src_same = + decide_dynamic_value_attrs_mapping( + graph_input2, + get_tensor_bindings_for_slot_name(weight_mapping_same, TensorSlotName::OUTPUT)); + + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ + { + mk_slot(TensorSlotName::INPUT), + graph_input1, + }, + { + mk_slot(TensorSlotName::WEIGHT), + graph_input2, + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::FWD, + /*device_coord=*/std::nullopt, + /*mapping=*/invocation_mapping, + /*op_attrs=*/TrainingOperationAttrs{ + PCGOperatorAttrs{ + make_relu_attrs(), + }, + }, + /*layer_guid=*/ + dynamic_layer_guid_t{parallel_layer_guid_t{Node{invocation_id}}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + mk_slot(TensorSlotName::OUTPUT_1), + invocation_output1, + }, + { + mk_slot(TensorSlotName::OUTPUT_2), + invocation_output2, + }, + }, + }; - MappedOperatorTaskGroup weight_mapping_same = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_input_shard_binding(mc1_weight_coord), + auto mk_copy = [&](DynamicValueAttrs const &src, + DynamicValueAttrs const &dst) { + return DynamicNodeInvocation{ + /*inputs=*/{{mk_slot(TensorSlotName::INPUT), src}}, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::FWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs*/ TrainingOperationAttrs{CopyAttrs{}}, + /*layer_guid=*/dynamic_layer_guid_t{dynamic_copy_layer_guid_t{}}, + /*per_device_op_state=*/std::nullopt, }, - { - mc2, - mk_input_shard_binding(mc2_weight_coord), + /*outputs=*/{{mk_slot(TensorSlotName::OUTPUT), dst}}, + }; + }; + + SUBCASE("same mapping, no copies") { + std::unordered_map sources_same{ + {graph_input1, graph_input1_src_same}, + {graph_input2, graph_input2_src_same}, + }; + + std::unordered_set result = + copies_for_invocation_inputs(input, sources_same); + + std::unordered_set correct = {}; + + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + + SUBCASE("copy one tensor, one point") { + MappedOperatorTaskGroup input_mapping_copy1 = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_input_shard_binding(mc1_input_coord), + }, + { + mc3, + mk_input_shard_binding(mc2_input_coord), + }, }, - }, - }; + }; - MappedOperatorTaskGroup invocation_mapping = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_shard_binding(mc1_input_coord, - mc1_weight_coord, - mc1_output_1_coord, - mc1_output_2_coord), + MappedOperatorTaskGroup input_mapping_copy1_diff_vs_use = + MappedOperatorTaskGroup{ + bidict{ + { + mc3, + mk_input_shard_binding(mc2_input_coord), + }, + }, + }; + + DynamicValueAttrs graph_input1_src_copy1 = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(input_mapping_copy1, TensorSlotName::OUTPUT)); + + DynamicValueAttrs graph_input1_src_copy1_diff_vs_use = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(input_mapping_copy1_diff_vs_use, TensorSlotName::OUTPUT)); + + std::unordered_map sources_copy1{ + {graph_input1, graph_input1_src_copy1}, + {graph_input2, graph_input2_src_same}}; + + std::unordered_set result = + copies_for_invocation_inputs(input, sources_copy1); + + std::unordered_set correct = { + mk_copy(graph_input1_src_copy1_diff_vs_use, graph_input1_use_diff_vs_copy1), + }; + + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + + SUBCASE("copy two tensors, two points") { + MappedOperatorTaskGroup input_mapping_copy2 = MappedOperatorTaskGroup{ + bidict{ + { + mc3, + mk_input_shard_binding(mc1_input_coord), + }, + { + mc4, + mk_input_shard_binding(mc2_input_coord), + }, }, - { - mc2, - mk_shard_binding(mc2_input_coord, - mc2_weight_coord, - mc2_output_1_coord, - mc2_output_2_coord), + }; + MappedOperatorTaskGroup weight_mapping_copy2 = MappedOperatorTaskGroup{ + bidict{ + { + mc4, + mk_input_shard_binding(mc1_weight_coord), + }, + { + mc3, + mk_input_shard_binding(mc2_weight_coord), + }, }, - }, - }; + }; - MappedOperatorTaskGroup invocation_mapping_diff_vs_copy1 = - MappedOperatorTaskGroup{ - bidict{ + DynamicValueAttrs graph_input1_src_copy2 = + decide_dynamic_value_attrs_mapping( + graph_input1, + get_tensor_bindings_for_slot_name(input_mapping_copy2, TensorSlotName::OUTPUT)); + + DynamicValueAttrs graph_input2_src_copy2 = + decide_dynamic_value_attrs_mapping( + graph_input2, + get_tensor_bindings_for_slot_name(weight_mapping_copy2, TensorSlotName::OUTPUT)); + + std::unordered_map sources_copy2{ + {graph_input1, graph_input1_src_copy2}, + {graph_input2, graph_input2_src_copy2}}; + + std::unordered_set result = + copies_for_invocation_inputs(input, sources_copy2); + + std::unordered_set correct = { + mk_copy(graph_input1_src_copy2, graph_input1_use), + mk_copy(graph_input2_src_copy2, graph_input2_use), + }; + + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + } + + SUBCASE("replicate operator") { + + auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + ParallelTensorSpaceCoordinate const &c2) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::INPUT, + c1, + }, { - mc2, - mk_shard_binding(mc2_input_coord, - mc2_weight_coord, - mc2_output_1_coord, - mc2_output_2_coord), + TensorSlotName::OUTPUT, + c2, }, }, }; - auto mk_slot = [](TensorSlotName const &slot_name) -> DynamicTensorSlot { - return DynamicTensorSlot{ - /*slot_name=*/slot_name, - /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), }; - }; - auto mk_value = [&](size_t src_node_id, - TensorSlotName src_slot_name, - MappedOperatorTaskGroup const &mapping, - std::optional const &use_slot_name) - -> DynamicValueAttrs { - return DynamicValueAttrs{ - /*tensor_guid=*/dynamic_tensor_guid_t{parallel_tensor_guid_t{ - KwargDataflowOutput{ - Node{src_node_id}, - src_slot_name, + ParallelTensorSpaceCoordinate mc_input_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + ParallelTensorSpaceCoordinate mc1_output_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc2_output_coord = + mk_pt_coord(0_n, 1_n, 0_n, 0_n); + + MappedOperatorTaskGroup invocation_mapping = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_shard_binding(mc_input_coord, + mc1_output_coord), }, - }}, - /*parallel_tensor_shape=*/std::nullopt, - /*shard_coord=*/std::nullopt, - /*mapping=*/ - transform(use_slot_name, - [&](TensorSlotName s) { - return get_tensor_bindings_for_slot_name(mapping, s); - }), - /*accessor=*/std::nullopt, - /*role=*/std::nullopt, + { + mc2, + mk_shard_binding(mc_input_coord, + mc2_output_coord), + }, + }, }; - }; - size_t invocation1_id = 20; - - DynamicValueAttrs graph_input1 = - mk_value(0, TensorSlotName::OUTPUT, invocation_mapping, std::nullopt); - DynamicValueAttrs graph_input1_use = mk_value( - 0, TensorSlotName::OUTPUT, invocation_mapping, TensorSlotName::INPUT); - DynamicValueAttrs graph_input1_use_diff_vs_copy1 = - mk_value(0, - TensorSlotName::OUTPUT, - invocation_mapping_diff_vs_copy1, - TensorSlotName::INPUT); - DynamicValueAttrs graph_input2 = - mk_value(1, TensorSlotName::OUTPUT, invocation_mapping, std::nullopt); - DynamicValueAttrs graph_input2_use = mk_value( - 1, TensorSlotName::OUTPUT, invocation_mapping, TensorSlotName::WEIGHT); - DynamicValueAttrs invocation1_output1 = mk_value(invocation1_id, - TensorSlotName::OUTPUT_1, - invocation_mapping, - std::nullopt); - DynamicValueAttrs invocation1_output1_src = - mk_value(invocation1_id, - TensorSlotName::OUTPUT_1, - invocation_mapping, - TensorSlotName::OUTPUT_1); - DynamicValueAttrs invocation1_output2 = mk_value(invocation1_id, - TensorSlotName::OUTPUT_2, - invocation_mapping, - std::nullopt); - DynamicValueAttrs invocation1_output2_src = - mk_value(invocation1_id, - TensorSlotName::OUTPUT_2, - invocation_mapping, - TensorSlotName::OUTPUT_2); - - DynamicValueAttrs graph_input1_src_same = mk_value( - 0, TensorSlotName::OUTPUT, input_mapping_same, TensorSlotName::OUTPUT); - DynamicValueAttrs graph_input2_src_same = mk_value( - 1, TensorSlotName::OUTPUT, weight_mapping_same, TensorSlotName::OUTPUT); - - DynamicNodeInvocation input = DynamicNodeInvocation{ + DynamicValueAttrs graph_input_unmapped = + mk_value(0, TensorSlotName::OUTPUT); + DynamicValueAttrs graph_input_use_mapped = + decide_dynamic_value_attrs_mapping( + graph_input_unmapped, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::INPUT)); + + DynamicValueAttrs invocation_output_unmapped = + mk_value(invocation_id, TensorSlotName::OUTPUT); + DynamicValueAttrs invocation_output_src_mapped = + decide_dynamic_value_attrs_mapping( + invocation_output_unmapped, + get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::OUTPUT)); + + DynamicNodeInvocation input = DynamicNodeInvocation{ /*inputs=*/{ { mk_slot(TensorSlotName::INPUT), - graph_input1, - }, - { - mk_slot(TensorSlotName::WEIGHT), - graph_input2, + graph_input_unmapped, }, }, - /*node_attrs=*/ - DynamicNodeAttrs{ + /*node_attrs=*/DynamicNodeAttrs{ /*task_type=*/DynamicTaskType::FWD, /*device_coord=*/std::nullopt, /*mapping=*/invocation_mapping, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/ - dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - { - mk_slot(TensorSlotName::OUTPUT_1), - invocation1_output1, - }, - { - mk_slot(TensorSlotName::OUTPUT_2), - invocation1_output2, - }, - }, - }; - - DynamicNodeInvocation mapped = DynamicNodeInvocation{ - /*inputs=*/{ - { - mk_slot(TensorSlotName::INPUT), - graph_input1_use, + /*op_attrs=*/TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + 2_p, + }, + }, }, - { - mk_slot(TensorSlotName::WEIGHT), - graph_input2_use, + /*layer_guid=*/dynamic_layer_guid_t{ + parallel_layer_guid_t{ + Node{invocation_id}, + }, }, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/DynamicTaskType::FWD, - /*device_coord=*/std::nullopt, - /*mapping=*/invocation_mapping, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/ - dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, /*per_device_op_state=*/std::nullopt, }, - /*outputs=*/ - { - { - mk_slot(TensorSlotName::OUTPUT_1), - invocation1_output1_src, - }, + /*outputs=*/{ { - mk_slot(TensorSlotName::OUTPUT_2), - invocation1_output2_src, + mk_slot(TensorSlotName::OUTPUT), + invocation_output_unmapped, }, }, - }; - - auto mk_copy = [&](DynamicValueAttrs const &src, - DynamicValueAttrs const &dst) { - return DynamicNodeInvocation{ - /*inputs=*/{{mk_slot(TensorSlotName::INPUT), src}}, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/DynamicTaskType::FWD, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs*/ TrainingOperationAttrs{CopyAttrs{}}, - /*layer_guid=*/dynamic_layer_guid_t{dynamic_copy_layer_guid_t{}}, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/{{mk_slot(TensorSlotName::OUTPUT), dst}}, }; - }; - - SUBCASE("same mapping, no copies") { - std::unordered_map sources_same{ - {graph_input1, graph_input1_src_same}, - {graph_input2, graph_input2_src_same}}; - - std::unordered_set result = - perform_copy_insertion_for_invocation(input, sources_same); - std::unordered_set correct = {mapped}; - - CHECK(result.size() == correct.size()); - CHECK(result == correct); - } - - SUBCASE("copy one tensor, one point") { - MappedOperatorTaskGroup input_mapping_copy1 = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_input_shard_binding(mc1_input_coord), - }, - { - mc3, - mk_input_shard_binding(mc2_input_coord), - }, + std::unordered_map unmapped_to_mapped_source_value = { + { + graph_input_unmapped, + decide_dynamic_value_attrs_mapping( + graph_input_unmapped, + OneToMany{ + { + mc_input_coord, + {mc3}, + }, + }) }, - }; - MappedOperatorTaskGroup input_mapping_copy1_diff_vs_use = - MappedOperatorTaskGroup{ - bidict{ - { - mc3, - mk_input_shard_binding(mc2_input_coord), - }, - }, - }; - - DynamicValueAttrs graph_input1_src_copy1 = - mk_value(0, - TensorSlotName::OUTPUT, - input_mapping_copy1, - TensorSlotName::OUTPUT); - DynamicValueAttrs graph_input1_src_copy1_diff_vs_use = - mk_value(0, - TensorSlotName::OUTPUT, - input_mapping_copy1_diff_vs_use, - TensorSlotName::OUTPUT); - - std::unordered_map sources_copy1{ - {graph_input1, graph_input1_src_copy1}, - {graph_input2, graph_input2_src_same}}; - - std::unordered_set result = - perform_copy_insertion_for_invocation(input, sources_copy1); - - std::unordered_set correct = { - mapped, - mk_copy(graph_input1_src_copy1_diff_vs_use, - graph_input1_use_diff_vs_copy1), - }; + }; - CHECK(result.size() == correct.size()); - CHECK(result == correct); - } + std::unordered_set result = copies_for_invocation_inputs( + input, unmapped_to_mapped_source_value); - SUBCASE("copy two tensors, two points") { - MappedOperatorTaskGroup input_mapping_copy2 = MappedOperatorTaskGroup{ - bidict{ - { - mc3, - mk_input_shard_binding(mc1_input_coord), - }, - { - mc4, - mk_input_shard_binding(mc2_input_coord), - }, - }, - }; - MappedOperatorTaskGroup weight_mapping_copy2 = MappedOperatorTaskGroup{ - bidict{ - { - mc4, - mk_input_shard_binding(mc1_weight_coord), - }, - { - mc3, - mk_input_shard_binding(mc2_weight_coord), - }, - }, - }; + std::unordered_set correct = {}; - DynamicValueAttrs graph_input1_src_copy2 = - mk_value(0, - TensorSlotName::OUTPUT, - input_mapping_copy2, - TensorSlotName::OUTPUT); - DynamicValueAttrs graph_input2_src_copy2 = - mk_value(1, - TensorSlotName::OUTPUT, - weight_mapping_copy2, - TensorSlotName::OUTPUT); - - std::unordered_map sources_copy2{ - {graph_input1, graph_input1_src_copy2}, - {graph_input2, graph_input2_src_copy2}}; - - std::unordered_set result = - perform_copy_insertion_for_invocation(input, sources_copy2); - - std::unordered_set correct = { - mapped, - mk_copy(graph_input1_src_copy2, graph_input1_use), - mk_copy(graph_input2_src_copy2, graph_input2_use), - }; + nlohmann::json result_j = transform(result, dynamic_node_invocation_to_serializable); + nlohmann::json correct_j = transform(correct, dynamic_node_invocation_to_serializable); - CHECK(result.size() == correct.size()); - CHECK(result == correct); + CHECK(result_j == correct_j); } + + // SUBCASE("reduction operator") { + + // auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + // ParallelTensorSpaceCoordinate const &c2) + // -> OperatorAtomicTaskShardBinding { + // return OperatorAtomicTaskShardBinding{ + // /*tensor_coords=*/{ + // { + // TensorSlotName::INPUT, + // c1, + // }, + // { + // TensorSlotName::OUTPUT, + // c2, + // }, + // }, + // }; + // }; + + // ParallelTensorSpaceCoordinate mc1_input_coord = + // mk_pt_coord(0_n, 0_n, 0_n, 0_n); + // ParallelTensorSpaceCoordinate mc2_input_coord = + // mk_pt_coord(1_n, 0_n, 0_n, 0_n); + + // ParallelTensorSpaceCoordinate mc_output_coord = + // mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + // MappedOperatorTaskGroup invocation_mapping = MappedOperatorTaskGroup{ + // bidict{ + // { + // mc3, + // mk_shard_binding(mc1_input_coord, + // mc_output_coord), + // }, + // }, + // }; + + // DynamicValueAttrs graph_input_unmapped = + // mk_value(0, TensorSlotName::OUTPUT); + // DynamicValueAttrs graph_input_use_mapped = + // decide_dynamic_value_attrs_mapping( + // graph_input_unmapped, + // get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::INPUT)); + + // DynamicValueAttrs invocation_output_unmapped = + // mk_value(invocation_id, TensorSlotName::OUTPUT); + // DynamicValueAttrs invocation_output_src_mapped = + // decide_dynamic_value_attrs_mapping( + // invocation_output_unmapped, + // get_tensor_bindings_for_slot_name(invocation_mapping, TensorSlotName::OUTPUT)); + + // DynamicNodeInvocation input = DynamicNodeInvocation{ + // /*inputs=*/{ + // { + // mk_slot(TensorSlotName::INPUT), + // graph_input_unmapped, + // }, + // }, + // /*node_attrs=*/DynamicNodeAttrs{ + // /*task_type=*/DynamicTaskType::FWD, + // /*device_coord=*/std::nullopt, + // /*mapping=*/invocation_mapping, + // /*op_attrs=*/TrainingOperationAttrs{ + // PCGOperatorAttrs{ + // ReductionAttrs{ + // /*reduction_degree=*/2_p, + // }, + // }, + // }, + // /*layer_guid=*/dynamic_layer_guid_t{ + // parallel_layer_guid_t{ + // Node{invocation_id}, + // }, + // }, + // /*per_device_op_state=*/std::nullopt, + // }, + // /*outputs=*/{ + // { + // mk_slot(TensorSlotName::OUTPUT), + // invocation_output_unmapped, + // }, + // }, + // }; + + // std::unordered_map unmapped_to_mapped_source_value = { + // { + // graph_input_unmapped, + // decide_dynamic_value_attrs_mapping( + // graph_input_unmapped, + // OneToMany{ + // { + // mc1_input_coord, + // {mc1}, + // }, + // { + // mc2_input_coord, + // {mc2}, + // }, + // }) + // }, + // }; + + // std::unordered_set result = copies_for_invocation_inputs( + // input, unmapped_to_mapped_source_value); + + // std::unordered_set correct = {}; + + // nlohmann::json result_j = transform(result, dynamic_node_invocation_to_serializable); + // nlohmann::json correct_j = transform(correct, dynamic_node_invocation_to_serializable); + + // CHECK(result_j == correct_j); + // } } } diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc index 49b8d4a77a..96523b6c31 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/dynamic_open_dataflow_graph.cc @@ -58,33 +58,40 @@ TEST_SUITE(FF_TEST_SUITE) { }; DynamicNodeInvocation invocation_1 = DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ - {DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::INPUT, - /*slot_tensor_role=*/std::nullopt, - }, - value_1}, + /*inputs=*/std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, + }, + value_1, + }, }, /*node_attrs=*/node_attrs, /*outputs=*/ - std::unordered_map{ - {DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::OUTPUT, - /*slot_tensor_role=*/std::nullopt, - }, - value_2}, + std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, + }, + value_2, + }, }, }; DynamicNodeInvocation invocation_2 = DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{}, + /*inputs=*/std::map{}, /*node_attrs=*/node_attrs, /*outputs=*/ - std::unordered_map{ + std::map{ { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::OUTPUT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, value_3, }, @@ -92,11 +99,12 @@ TEST_SUITE(FF_TEST_SUITE) { }; DynamicNodeInvocation invocation_3 = DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ + /*inputs=*/std::map{ { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::INPUT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, value_1, }, @@ -104,6 +112,7 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::WEIGHT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, value_2, }, @@ -111,12 +120,13 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::BIAS, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, value_1, }, }, /*node_attrs=*/node_attrs, - /*outputs=*/std::unordered_map{}, + /*outputs=*/std::map{}, }; std::unordered_set invocation_set = { diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc index 40b3460ee5..789d89a676 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/machine_slicing.cc @@ -58,6 +58,7 @@ TEST_SUITE(FF_TEST_SUITE) { return DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }; }; @@ -112,7 +113,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, - /*device_coord=*/mc2, + /*device_coords=*/nonempty_set{mc2}, /*mapping=*/std::nullopt, /*op_attrs=*/std::nullopt, /*layer_guid=*/ @@ -139,7 +140,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, - /*device_coord=*/mc1, + /*device_coord=*/nonempty_set{mc1}, /*mapping=*/std::nullopt, /*op_attrs=*/std::nullopt, /*layer_guid=*/ @@ -173,7 +174,7 @@ TEST_SUITE(FF_TEST_SUITE) { /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, - /*device_coord=*/mc2, + /*device_coord=*/nonempty_set{mc2}, /*mapping=*/std::nullopt, /*op_attrs=*/std::nullopt, /*layer_guid=*/ diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc new file mode 100644 index 0000000000..2e14c88654 --- /dev/null +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.cc @@ -0,0 +1,398 @@ +#include +#include "task-spec/dynamic_graph/make_dynamic_open_dataflow_graph_from_mapped_pcg.h" +#include "utils/containers/require_only_key.h" +#include "op-attrs/ops/element_unary.h" +#include "pcg/mapped_parallel_computation_graph/mapped_parallel_computation_graph.h" + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("make_dynamic_node_invocation_from_mapped") { + SUBCASE("Replicate") { + MachineSpaceCoordinate gpu0 = MachineSpaceCoordinate{0_n, 0_n, DeviceType::GPU}; + MachineSpaceCoordinate gpu1 = MachineSpaceCoordinate{0_n, 1_n, DeviceType::GPU}; + + ParallelTensorSpaceCoordinate tensor_coord0 = ParallelTensorSpaceCoordinate{ + /*sum_component=*/0_n, + /*discard_copy_component=*/0_n, + /*shard_component=*/FFOrdered{0_n}, + }; + + ParallelTensorSpaceCoordinate tensor_coord1 = ParallelTensorSpaceCoordinate{ + /*sum_component=*/0_n, + /*discard_copy_component=*/1_n, + /*shard_component=*/FFOrdered{0_n}, + }; + + MappedOperatorTaskGroup mapping = MappedOperatorTaskGroup{ + { + { + gpu0, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord0}, + }}, + }, + { + gpu1, + OperatorAtomicTaskShardBinding{{ + {TensorSlotName::OUTPUT, tensor_coord1}, + }}, + }, + }, + }; + + ParallelTensorShape input_shape = ParallelTensorShape{ + /*dims=*/ParallelTensorDims{ + /*shard_dims=*/FFOrdered{ + ShardParallelDim{8_p, 2_p}, + ShardParallelDim{5_p, 1_p}, + }, + /*replica_dims=*/ReplicaParallelDimSet{ + SumDegree{1_p}, + DiscardCopyDegree{1_p}, + }, + }, + /*data_type=*/DataType::FLOAT, + }; + + ParallelTensorShape output_shape = [&] { + ParallelTensorShape shape = input_shape; + shape.dims.replica_dims.discard_copy_degree = DiscardCopyDegree{2_p}; + return shape; + }(); + + PCGOperatorAttrs op_attrs = PCGOperatorAttrs{ + ReplicateAttrs{ + 2_p, + }, + }; + + parallel_layer_guid_t layer_guid = parallel_layer_guid_t{Node{0}}; + parallel_tensor_guid_t input_tensor_guid = parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{5}, + TensorSlotName::OUTPUT, + }, + }; + parallel_tensor_guid_t output_tensor_guid = parallel_tensor_guid_t{ + KwargDataflowOutput{ + Node{0}, + TensorSlotName::OUTPUT, + }, + }; + + MappedParallelLayerInvocationInfo input = MappedParallelLayerInvocationInfo{ + /*incoming=*/{ + { + TensorSlotName::INPUT, + ParallelTensorInfo{ + /*guid=*/input_tensor_guid, + /*attrs=*/ParallelTensorAttrs{ + /*shape=*/input_shape, + /*create_grad=*/CreateGrad::YES, + }, + }, + }, + }, + /*layer_info=*/MappedParallelLayerInfo{ + /*guid=*/layer_guid, + /*attrs=*/ParallelLayerAttrs{ + /*op_attrs=*/op_attrs, + /*name=*/std::nullopt, + }, + /*mapping=*/mapping, + }, + /*outgoing=*/{ + { + TensorSlotName::OUTPUT, + ParallelTensorInfo{ + /*guid=*/output_tensor_guid, + /*attrs=*/ParallelTensorAttrs{ + /*shape=*/output_shape, + /*create_grad=*/CreateGrad::YES, + }, + }, + }, + }, + }; + + DynamicNodeInvocation result = make_dynamic_node_invocation_from_mapped(input); + + DynamicNodeInvocation correct = DynamicNodeInvocation{ + /*inputs=*/{ + { + DynamicTensorSlot{ + TensorSlotName::INPUT, + /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, + }, + DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{input_tensor_guid}, + /*parallel_tensor_shape=*/input_shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }, + }, + }, + /*node_attrs=*/DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/mapping, + /*op_attrs=*/TrainingOperationAttrs{op_attrs}, + /*layer_guid=*/dynamic_layer_guid_t{layer_guid}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/{ + { + DynamicTensorSlot{ + TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, + }, + DynamicValueAttrs{ + /*tensor_guid=*/dynamic_tensor_guid_t{output_tensor_guid}, + /*parallel_tensor_shape=*/output_shape, + /*shard_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*accessor=*/std::nullopt, + /*role=*/std::nullopt, + }, + }, + } + }; + + CHECK(result == correct); + } + + // SUBCASE("standard op") { + // + // } + } + + // TEST_CASE("make_dynamic_open_dataflow_graph_from_mapped_pcg") { + // positive_int batch_size = 10_p; + // positive_int data_dim = 16_p; + // positive_int hidden_dim = 32_p; + // positive_int output_dim = 1_p; + + // auto make_layer_attrs = [](auto const &op_attrs) -> ParallelLayerAttrs { + // return ParallelLayerAttrs{ + // /*op_attrs=*/PCGOperatorAttrs{op_attrs}, + // /*name=*/std::nullopt, + // }; + // }; + + + // TensorShape output_tensor_shape = TensorShape{ + // TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // TensorShape label_tensor_shape = TensorShape{ + // TensorDims{FFOrdered{batch_size, output_dim}}, DataType::FLOAT}; + + // ParallelComputationGraph pcg = empty_parallel_computation_graph(); + + // TensorShape input_tensor_shape = TensorShape{ + // TensorDims{FFOrdered{batch_size, data_dim}}, DataType::FLOAT}; + + // ParallelLayerAddedResult inputs_layer = + // pcg_add_input_layer(pcg, input_tensor_shape); + // parallel_tensor_guid_t t_input = + // require_only_key(inputs_layer.outputs, TensorSlotName::OUTPUT); + + // ParallelLayerAddedResult inputs_layer_2 = + // pcg_add_input_layer(pcg, input_tensor_shape); + // parallel_tensor_guid_t t_input_2 = + // require_only_key(inputs_layer_2.outputs, TensorSlotName::OUTPUT); + + // ElementBinaryAttrs add_attrs = ElementBinaryAttrs{ + // OperatorType::EW_ADD, + // DataType::FLOAT, + // false, + // false, + // }; + + // ParallelLayerAddedResult add_operator_1 = + // add_parallel_layer(pcg, + // make_layer_attrs(add_attrs), + // { + // { + // TensorSlotName::LHS_INPUT, + // t_input, + // }, + // { + // TensorSlotName::RHS_INPUT, + // t_input_2, + // }, + // }, + // /*weights=*/{}); + + // parallel_tensor_guid_t t_add_1 = + // require_only_key(add_operator_1.outputs, TensorSlotName::OUTPUT); + + // positive_int replicate_degree = 2_p; + // ReplicateAttrs repl_attrs = ReplicateAttrs{replicate_degree}; + // ParallelLayerAddedResult repl_operator_1 = + // add_parallel_layer(pcg, + // make_layer_attrs(repl_attrs), + // { + // { + // TensorSlotName::INPUT, + // t_add_1, + // }, + // }, + // /*weight=*/{}); + + // parallel_tensor_guid_t t_repl_1 = + // require_only_key(repl_operator_1.outputs, TensorSlotName::OUTPUT); + + // ParallelLayerAddedResult relu_operator_1 = + // add_parallel_layer(pcg, + // make_layer_attrs(make_relu_attrs()), + // /*inputs=*/ + // { + // { + // TensorSlotName::INPUT, + // t_repl_1, + // }, + // }, + // /*weights=*/{}); + + // parallel_tensor_guid_t t_relu_1 = + // require_only_key(relu_operator_1.outputs, TensorSlotName::OUTPUT); + + // MachineSpaceCoordinate gpu0{0_n, 0_n, DeviceType::GPU}; + // MachineSpaceCoordinate gpu1{0_n, 1_n, DeviceType::GPU}; + + // ParallelTensorSpaceCoordinate tensor_coord0{ + // /*sum_component=*/0_n, + // /*discard_copy_component=*/0_n, + // /*shard_component=*/FFOrdered{0_n}}; + // ParallelTensorSpaceCoordinate tensor_coord1{ + // /*sum_component=*/0_n, + // /*discard_copy_component=*/1_n, + // /*shard_component=*/FFOrdered{0_n}}; + + // MappedOperatorTaskGroup input_1_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // }, + // }; + + // MappedOperatorTaskGroup input_2_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // }, + // }; + + // MappedOperatorTaskGroup add_operator_1_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::LHS_INPUT, tensor_coord0}, + // {TensorSlotName::RHS_INPUT, tensor_coord0}, + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // }, + // }; + + // MappedOperatorTaskGroup repl_operator_1_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // { + // gpu1, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::OUTPUT, tensor_coord1}, + // }}, + // }, + // }, + // }; + + // MappedOperatorTaskGroup relu_operator_1_mapping = MappedOperatorTaskGroup{ + // { + // { + // gpu0, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::INPUT, tensor_coord0}, + // {TensorSlotName::OUTPUT, tensor_coord0}, + // }}, + // }, + // { + // gpu1, + // OperatorAtomicTaskShardBinding{{ + // {TensorSlotName::INPUT, tensor_coord1}, + // {TensorSlotName::OUTPUT, tensor_coord1}, + // }}, + // }, + // }, + // }; + + // MappedParallelComputationGraph mpcg = mapped_pcg_from_pcg_and_mapped_op_task_groups( + // /*pcg=*/pcg, + // /*mapped_op_task_groups=*/{ + // { + // inputs_layer.parallel_layer, + // input_1_mapping, + // }, + // { + // inputs_layer_2.parallel_layer, + // input_2_mapping, + // }, + // { + // add_operator_1.parallel_layer, + // add_operator_1_mapping, + // }, + // { + // repl_operator_1.parallel_layer, + // repl_operator_1_mapping, + // }, + // { + // relu_operator_1.parallel_layer, + // relu_operator_1_mapping, + // }, + // }); + + + // DynamicOpenDataflowGraph result = make_dynamic_open_dataflow_graph_from_mapped_pcg(mpcg); + + // DynamicNodeInvocation input_1_invocation = DynamicNodeInvocation{ + // DynamicNodeAttrs{ + // /*task_type=*/std::nullopt, + // /*device_coord=*/std::nullopt, + // /*mapping=*/input_1_mapping, + // /*op_attrs=*/TrainingOperationAttrs{ + // /*pcg_layer_guid=*/ + // /*per_device_op_state=*/std::nullopt, + // }, + // }; + + // DynamicNodeInvocation input_2_invocation = + // DynamicNodeInvocation add_operator_1_invocation = + // DynamicNodeInvocation repl_operator_1_invocation = + // DynamicNodeInvocation relu_operator_1_invocation = + + // DynamicOpenDataflowGraph correct = dynamic_open_dataflow_graph_from_invocation_set( + // /*invocations=*/{ + + // }, + // }; + // } +} diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc index fb087f5295..90fbdec5f7 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/pass_expansion.cc @@ -1,4 +1,5 @@ #include "task-spec/dynamic_graph/pass_expansion.h" +#include "op-attrs/ops/element_unary.h" #include "task-spec/dynamic_graph/dynamic_open_dataflow_graph.h" #include "task-spec/dynamic_graph/dynamic_tensor_role.h" #include @@ -31,11 +32,24 @@ TEST_SUITE(FF_TEST_SUITE) { return DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/role, + /*task_shard=*/std::nullopt, }; }; dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, + }, + }; + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); @@ -46,14 +60,13 @@ TEST_SUITE(FF_TEST_SUITE) { {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, - {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/layer_guid, /*per_device_op_state=*/std::nullopt, }, @@ -79,14 +92,13 @@ TEST_SUITE(FF_TEST_SUITE) { {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::SCALE, fwd_role), v1_fwd}, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/DynamicTaskType::FWD, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/layer_guid, /*per_device_op_state=*/std::nullopt, }, @@ -125,93 +137,168 @@ TEST_SUITE(FF_TEST_SUITE) { return DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/role, + /*task_shard=*/std::nullopt, }; }; dynamic_layer_guid_t layer_guid{parallel_layer_guid_t{Node{20}}}; - DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { - DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); - DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); - DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); - - return DynamicNodeInvocation{ - /*inputs=*/{ - {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, - {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, - {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, - {mk_slot(TensorSlotName::SCALE, std::nullopt), v1}, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/layer_guid, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, + DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); + DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); + DynamicValueAttrs v3 = mk_value_attrs(2, std::nullopt); + + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; + + DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); + DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); + DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); + DynamicValueAttrs v1_grad = mk_value_attrs(0, grad_role); + DynamicValueAttrs v2_grad = mk_value_attrs(1, grad_role); + DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); + + SUBCASE("normal operator") { + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + LinearAttrs{ + /*out_channels=*/8_p, + /*use_bias=*/true, + /*data_type=*/DataType::FLOAT, + /*activation=*/std::nullopt, + /*regularizer=*/std::nullopt, + }, }, }; - }(); - - DynamicNodeInvocation result = - perform_bwd_pass_expansion_for_invocation(invocation); - DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { - DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; - DynamicTensorRole grad_role = DynamicTensorRole{FwbTensorType::GRADIENT}; - - DynamicValueAttrs v1_fwd = mk_value_attrs(0, fwd_role); - DynamicValueAttrs v2_fwd = mk_value_attrs(1, fwd_role); - DynamicValueAttrs v3_fwd = mk_value_attrs(2, fwd_role); - DynamicValueAttrs v1_grad = mk_value_attrs(0, grad_role); - DynamicValueAttrs v2_grad = mk_value_attrs(1, grad_role); - DynamicValueAttrs v3_grad = mk_value_attrs(2, grad_role); - - return DynamicNodeInvocation{ - /*inputs=*/{ - {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, - {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::SCALE, fwd_role), v1_fwd}, - {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, - {mk_slot(TensorSlotName::OUTPUT, grad_role), v3_grad}, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*pass_type=*/DynamicTaskType::BWD, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/layer_guid, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, - {mk_slot(TensorSlotName::WEIGHT, grad_role), v2_grad}, - {mk_slot(TensorSlotName::BIAS, grad_role), v1_grad}, - {mk_slot(TensorSlotName::SCALE, grad_role), v1_grad}, + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + {mk_slot(TensorSlotName::WEIGHT, std::nullopt), v2}, + {mk_slot(TensorSlotName::BIAS, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v3}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_bwd_pass_expansion_for_invocation(invocation); + + DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::WEIGHT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::BIAS, fwd_role), v1_fwd}, + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v3_fwd}, + {mk_slot(TensorSlotName::OUTPUT, grad_role), v3_grad}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*pass_type=*/DynamicTaskType::BWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, + {mk_slot(TensorSlotName::WEIGHT, grad_role), v2_grad}, + {mk_slot(TensorSlotName::BIAS, grad_role), v1_grad}, + }, + }; + }(); + + ASSERT(result == correct); + } + + SUBCASE("replicate operator optimization") { + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, }, }; - }(); - ASSERT(result == correct); + DynamicNodeInvocation invocation = [&]() -> DynamicNodeInvocation { + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::INPUT, std::nullopt), v1}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::OUTPUT, std::nullopt), v2}, + }, + }; + }(); + + DynamicNodeInvocation result = + perform_bwd_pass_expansion_for_invocation(invocation); + + DynamicNodeInvocation correct = [&]() -> DynamicNodeInvocation { + DynamicTensorRole fwd_role = DynamicTensorRole{FwbTensorType::FORWARD}; + DynamicTensorRole grad_role = + DynamicTensorRole{FwbTensorType::GRADIENT}; + + return DynamicNodeInvocation{ + /*inputs=*/{ + {mk_slot(TensorSlotName::OUTPUT, fwd_role), v2_fwd}, + {mk_slot(TensorSlotName::OUTPUT, grad_role), v2_grad}, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*pass_type=*/DynamicTaskType::BWD, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/op_attrs, + /*layer_guid=*/layer_guid, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + {mk_slot(TensorSlotName::INPUT, grad_role), v1_grad}, + }, + }; + }(); + + ASSERT(result == correct); + } } TEST_CASE("perform_pass_expansion(DynamicOpenDataflowGraph)") { auto mk_node_attrs = [](size_t layer_id, + TrainingOperationAttrs const &op_attrs, std::optional const &pass_type) -> DynamicNodeAttrs { return DynamicNodeAttrs{ /*pass_type=*/pass_type, /*device_coord=*/std::nullopt, /*mapping=*/std::nullopt, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/ dynamic_layer_guid_t{parallel_layer_guid_t{Node{layer_id}}}, /*per_device_op_state=*/std::nullopt, @@ -236,46 +323,73 @@ TEST_SUITE(FF_TEST_SUITE) { }; }; + TrainingOperationAttrs input_op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + InputAttrs{ + TensorShape{ + TensorDims{ + FFOrdered{ + 4_p, + 8_p, + }, + }, + DataType::FLOAT, + }, + }, + }, + }; + + TrainingOperationAttrs relu_op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + make_relu_attrs(), + }, + }; + DynamicOpenDataflowGraph input = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1 = mk_node_attrs(10, std::nullopt); - DynamicNodeAttrs n2 = mk_node_attrs(11, std::nullopt); + DynamicNodeAttrs n1 = mk_node_attrs(10, input_op_attrs, std::nullopt); + DynamicNodeAttrs n2 = mk_node_attrs(11, relu_op_attrs, std::nullopt); DynamicValueAttrs v1 = mk_value_attrs(0, std::nullopt); DynamicValueAttrs v2 = mk_value_attrs(1, std::nullopt); std::unordered_set invocation_set = { DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{}, + /*inputs=*/std::map{}, /*node_attrs=*/n1, /*outputs=*/ - std::unordered_map{ + std::map{ { DynamicTensorSlot{ /*slot_name=*/TensorSlotName::OUTPUT, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, }, v1, }, }, }, DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ - {DynamicTensorSlot{ + /*inputs=*/std::map{ + { + DynamicTensorSlot{ /*slot_name=*/TensorSlotName::INPUT, /*slot_tensor_role=*/std::nullopt, - }, - v1}, + /*task_shard=*/std::nullopt, + }, + v1, + }, }, /*node_attrs=*/n2, /*outputs=*/ - std::unordered_map{ - {DynamicTensorSlot{ - /*slot_name=*/TensorSlotName::OUTPUT, - /*slot_tensor_role=*/std::nullopt, - }, - v2}, + std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/std::nullopt, + }, + v2, + }, }, }, }; @@ -286,10 +400,14 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicOpenDataflowGraph result = perform_pass_expansion(input); DynamicOpenDataflowGraph correct = [&]() -> DynamicOpenDataflowGraph { - DynamicNodeAttrs n1_fwd = mk_node_attrs(10, DynamicTaskType::FWD); - DynamicNodeAttrs n2_fwd = mk_node_attrs(11, DynamicTaskType::FWD); - DynamicNodeAttrs n1_bwd = mk_node_attrs(10, DynamicTaskType::BWD); - DynamicNodeAttrs n2_bwd = mk_node_attrs(11, DynamicTaskType::BWD); + DynamicNodeAttrs n1_fwd = + mk_node_attrs(10, input_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n2_fwd = + mk_node_attrs(11, relu_op_attrs, DynamicTaskType::FWD); + DynamicNodeAttrs n1_bwd = + mk_node_attrs(10, input_op_attrs, DynamicTaskType::BWD); + DynamicNodeAttrs n2_bwd = + mk_node_attrs(11, relu_op_attrs, DynamicTaskType::BWD); DynamicValueAttrs v1_activation = mk_value_attrs(0, mk_dynamic_tensor_role_fwd()); @@ -302,48 +420,51 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_set invocation_set = { DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{}, + /*inputs=*/std::map{}, /*node_attrs=*/n1_fwd, /*outputs=*/ - std::unordered_map{ + std::map{ std::pair{ DynamicTensorSlot{ /*slot_name=*/TensorSlotName::OUTPUT, /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v1_activation, }, }, }, DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ + /*inputs=*/std::map{ std::pair{ DynamicTensorSlot{ TensorSlotName::INPUT, mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v1_activation, }, }, /*node_attrs=*/n2_fwd, /*outputs=*/ - std::unordered_map{ + std::map{ std::pair{ DynamicTensorSlot{ TensorSlotName::OUTPUT, mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v2_activation, }, }, }, DynamicNodeInvocation{ - /*inputs=*/std::unordered_map{ + /*inputs=*/std::map{ std::pair{ DynamicTensorSlot{ TensorSlotName::INPUT, mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v1_activation, }, @@ -351,6 +472,7 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ TensorSlotName::OUTPUT, mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, }, v2_activation, }, @@ -358,17 +480,19 @@ TEST_SUITE(FF_TEST_SUITE) { DynamicTensorSlot{ TensorSlotName::OUTPUT, mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, }, v2_gradient, }, }, /*node_attrs=*/n2_bwd, /*outputs=*/ - std::unordered_map{ + std::map{ std::pair{ DynamicTensorSlot{ TensorSlotName::INPUT, mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, }, v1_gradient, }, diff --git a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc index efe21146db..7f23053943 100644 --- a/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc +++ b/lib/task-spec/test/src/task-spec/dynamic_graph/shard_expansion.cc @@ -4,8 +4,13 @@ #include "task-spec/dynamic_graph/dynamic_copy_layer_guid_t.dtg.h" #include "task-spec/dynamic_graph/training_operation_attrs.dtg.h" #include "test/utils/doctest/fmt/unordered_set.h" -#include "utils/bidict/algorithms/filter_keys.h" #include +#include "task-spec/dynamic_graph/dynamic_tensor_role.h" +#include "op-attrs/ops/element_unary.h" +#include "utils/one_to_many/one_to_many_filter_keys.h" +#include "utils/one_to_many/one_to_many_filter_values.h" +#include "utils/containers/map_from_pairs.h" +#include "utils/containers/binary_merge_disjoint_maps.h" using namespace ::FlexFlow; @@ -33,25 +38,30 @@ static ParallelTensorSpaceCoordinate mk_pt_coord(nonnegative_int idx1, }; }; -DynamicTensorSlot mk_slot(TensorSlotName const &slot_name) { +DynamicTensorSlot mk_slot(TensorSlotName const &slot_name, + std::optional const &task_shard = std::nullopt) { return DynamicTensorSlot{ /*slot_name=*/slot_name, /*slot_tensor_role=*/std::nullopt, + /*task_shard=*/task_shard, }; }; DynamicValueAttrs mk_value(size_t src_node_id, TensorSlotName src_slot_name, - bidict - tensor_binding, - std::optional const &shard_coord) { + OneToMany const &tensor_binding, + std::optional const &shard_coord, + std::optional const &role = std::nullopt) { + + OneToMany mapping = tensor_binding; if (shard_coord.has_value()) { - tensor_binding = filter_keys(tensor_binding, - [&](ParallelTensorSpaceCoordinate const &p) { - return p == shard_coord.value(); - }); + mapping = one_to_many_filter_keys(mapping, + [&](ParallelTensorSpaceCoordinate const &p) { + return p == shard_coord.value(); + }); } + return DynamicValueAttrs{ /*tensor_guid=*/dynamic_tensor_guid_t{parallel_tensor_guid_t{ KwargDataflowOutput{ @@ -61,173 +71,148 @@ DynamicValueAttrs }}, /*parallel_tensor_shape=*/std::nullopt, /*shard_coord=*/shard_coord, - /*mapping=*/ - tensor_binding, + /*mapping=*/mapping, /*accessor=*/std::nullopt, - /*role=*/std::nullopt, + /*role=*/role, }; }; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("perform_shard_expansion_for_invocation") { - auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, - ParallelTensorSpaceCoordinate const &c2, - ParallelTensorSpaceCoordinate const &c3, - ParallelTensorSpaceCoordinate const &c4) - -> OperatorAtomicTaskShardBinding { - return OperatorAtomicTaskShardBinding{ - /*tensor_coords=*/{ - { - TensorSlotName::INPUT, - c1, - }, - { - TensorSlotName::WEIGHT, - c2, - }, - { - TensorSlotName::OUTPUT_1, - c3, - }, - { - TensorSlotName::OUTPUT_2, - c4, - }, - }, - }; - }; - - MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); - MachineSpaceCoordinate mc2 = mk_machine_coord(2_n, 0_n); - - ParallelTensorSpaceCoordinate mc1_input_coord = - mk_pt_coord(0_n, 0_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate mc1_weight_coord = - mk_pt_coord(0_n, 1_n, 2_n, 0_n); - ParallelTensorSpaceCoordinate mc1_output_1_coord = - mk_pt_coord(1_n, 0_n, 0_n, 1_n); - ParallelTensorSpaceCoordinate mc1_output_2_coord = - mk_pt_coord(3_n, 0_n, 0_n, 0_n); - - ParallelTensorSpaceCoordinate mc2_input_coord = - mk_pt_coord(0_n, 1_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate mc2_weight_coord = - mk_pt_coord(0_n, 4_n, 2_n, 0_n); - ParallelTensorSpaceCoordinate mc2_output_1_coord = - mk_pt_coord(1_n, 2_n, 0_n, 1_n); - ParallelTensorSpaceCoordinate mc2_output_2_coord = - mk_pt_coord(0_n, 0_n, 0_n, 0_n); - - MappedOperatorTaskGroup mapped_task_group = MappedOperatorTaskGroup{ - bidict{ - { - mc1, - mk_shard_binding(mc1_input_coord, - mc1_weight_coord, - mc1_output_1_coord, - mc1_output_2_coord), - }, - { - mc2, - mk_shard_binding(mc2_input_coord, - mc2_weight_coord, - mc2_output_1_coord, - mc2_output_2_coord), - }, - }, - }; - + TEST_CASE("generate_shard_expansion_for_invocation") { auto mk_op_value = [&](size_t src_node_id, TensorSlotName src_slot_name, TensorSlotName use_slot_name, - std::optional const &shard_coord) + MappedOperatorTaskGroup const &mapped_task_group, + std::optional const &shard_coord, + std::optional const &role = std::nullopt) -> DynamicValueAttrs { - bidict + OneToMany tensor_binding = get_tensor_bindings_for_slot_name(mapped_task_group, use_slot_name); - return mk_value(src_node_id, src_slot_name, tensor_binding, shard_coord); + return mk_value(src_node_id, src_slot_name, tensor_binding, shard_coord, role); }; - DynamicNodeInvocation input = DynamicNodeInvocation{ - /*inputs=*/{ - { - mk_slot(TensorSlotName::INPUT), - mk_op_value(0, - TensorSlotName::OUTPUT, - TensorSlotName::INPUT, - std::nullopt), - }, - { - mk_slot(TensorSlotName::WEIGHT), - mk_op_value(1, - TensorSlotName::OUTPUT, - TensorSlotName::WEIGHT, - std::nullopt), - }, + auto mk_sharding_info = [&](TensorSlotName slot_name, + ParallelTensorSpaceCoordinate const &shard_coord, + MappedOperatorTaskGroup const &mapped_op_task_group, + MachineSpaceCoordinate const &device_coord) + -> std::pair + { + OneToMany + tensor_binding = get_tensor_bindings_for_slot_name(mapped_op_task_group, + slot_name); + return std::pair{ + mk_slot(slot_name), + DynamicValueAttrsShardingInfo{ + /*shard_coord=*/shard_coord, + /*mapping=*/one_to_many_filter_values(tensor_binding, + [&](MachineSpaceCoordinate const &c) -> bool { + return device_coord == c; + }), }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/mapped_task_group, - /*op_attrs=*/std::nullopt, - /*layer_guid=*/ - dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { - { - mk_slot(TensorSlotName::OUTPUT_1), - mk_op_value(20, - TensorSlotName::OUTPUT_1, - TensorSlotName::OUTPUT_1, - std::nullopt), - }, - { - mk_slot(TensorSlotName::OUTPUT_2), - mk_op_value(20, - TensorSlotName::OUTPUT_2, - TensorSlotName::OUTPUT_2, - std::nullopt), + }; + }; + + SUBCASE("standard operator") { + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(2_n, 0_n); + + auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + ParallelTensorSpaceCoordinate const &c2, + ParallelTensorSpaceCoordinate const &c3, + ParallelTensorSpaceCoordinate const &c4) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::INPUT, + c1, + }, + { + TensorSlotName::WEIGHT, + c2, + }, + { + TensorSlotName::OUTPUT_1, + c3, + }, + { + TensorSlotName::OUTPUT_2, + c4, + }, }, + }; + }; + + ParallelTensorSpaceCoordinate mc1_input_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc1_weight_coord = + mk_pt_coord(0_n, 1_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc1_output_1_coord = + mk_pt_coord(1_n, 0_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc1_output_2_coord = + mk_pt_coord(3_n, 0_n, 0_n, 0_n); + + ParallelTensorSpaceCoordinate mc2_input_coord = + mk_pt_coord(0_n, 1_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate mc2_weight_coord = + mk_pt_coord(0_n, 4_n, 2_n, 0_n); + ParallelTensorSpaceCoordinate mc2_output_1_coord = + mk_pt_coord(1_n, 2_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate mc2_output_2_coord = + mk_pt_coord(0_n, 0_n, 0_n, 0_n); + + TrainingOperationAttrs op_attrs = TrainingOperationAttrs{ + PCGOperatorAttrs{ + make_relu_attrs(), }, - }; + }; + + MappedOperatorTaskGroup mapped_task_group = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_shard_binding(mc1_input_coord, + mc1_weight_coord, + mc1_output_1_coord, + mc1_output_2_coord), + }, + { + mc2, + mk_shard_binding(mc2_input_coord, + mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), + }, + }, + }; - std::unordered_set result = - perform_shard_expansion_for_invocation(input); - - auto mk_invocation_shard = - [&](MachineSpaceCoordinate const &device_coord, - ParallelTensorSpaceCoordinate const &input_shard_coord, - ParallelTensorSpaceCoordinate const &weight_shard_coord, - ParallelTensorSpaceCoordinate const &output_1_shard_coord, - ParallelTensorSpaceCoordinate const &output_2_shard_coord) - -> DynamicNodeInvocation { - return DynamicNodeInvocation{ + DynamicNodeInvocation input = DynamicNodeInvocation{ /*inputs=*/{ { mk_slot(TensorSlotName::INPUT), mk_op_value(0, TensorSlotName::OUTPUT, TensorSlotName::INPUT, - input_shard_coord), + mapped_task_group, + std::nullopt), }, { mk_slot(TensorSlotName::WEIGHT), mk_op_value(1, TensorSlotName::OUTPUT, TensorSlotName::WEIGHT, - weight_shard_coord), + mapped_task_group, + std::nullopt), }, }, /*node_attrs=*/ DynamicNodeAttrs{ /*task_type=*/std::nullopt, - /*device_coord=*/device_coord, + /*device_coord=*/std::nullopt, /*mapping=*/mapped_task_group, - /*op_attrs=*/std::nullopt, + /*op_attrs=*/op_attrs, /*layer_guid=*/ dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, /*per_device_op_state=*/std::nullopt, @@ -239,112 +224,427 @@ TEST_SUITE(FF_TEST_SUITE) { mk_op_value(20, TensorSlotName::OUTPUT_1, TensorSlotName::OUTPUT_1, - output_1_shard_coord), + mapped_task_group, + std::nullopt), }, { mk_slot(TensorSlotName::OUTPUT_2), mk_op_value(20, TensorSlotName::OUTPUT_2, TensorSlotName::OUTPUT_2, - output_2_shard_coord), + mapped_task_group, + std::nullopt), }, }, }; - }; - std::unordered_set correct = { - mk_invocation_shard(mc1, - mc1_input_coord, - mc1_weight_coord, - mc1_output_1_coord, - mc1_output_2_coord), - mk_invocation_shard(mc2, - mc2_input_coord, - mc2_weight_coord, - mc2_output_1_coord, - mc2_output_2_coord), - }; + std::unordered_set result = + generate_shard_expansion_for_invocation(input); - CHECK(result.size() == correct.size()); - CHECK(result == correct); - } + auto mk_invocation_shard = + [&](MachineSpaceCoordinate const &device_coord, + ParallelTensorSpaceCoordinate const &input_shard_coord, + ParallelTensorSpaceCoordinate const &weight_shard_coord, + ParallelTensorSpaceCoordinate const &output_1_shard_coord, + ParallelTensorSpaceCoordinate const &output_2_shard_coord) + -> DynamicNodeInvocationShardingInfo { + return DynamicNodeInvocationShardingInfo{ + /*device_coord=*/nonempty_set{device_coord}, + /*value_sharding=*/{ + mk_sharding_info(TensorSlotName::INPUT, input_shard_coord, mapped_task_group, device_coord), + mk_sharding_info(TensorSlotName::WEIGHT, weight_shard_coord, mapped_task_group, device_coord), + mk_sharding_info(TensorSlotName::OUTPUT_1, output_1_shard_coord, mapped_task_group, device_coord), + mk_sharding_info(TensorSlotName::OUTPUT_2, output_2_shard_coord, mapped_task_group, device_coord), + }, + }; + }; - TEST_CASE("perform_shard_expansion_for_invocation (copy)") { - MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); - MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); - MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); - MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + std::unordered_set correct = { + mk_invocation_shard(mc1, + mc1_input_coord, + mc1_weight_coord, + mc1_output_1_coord, + mc1_output_2_coord), + mk_invocation_shard(mc2, + mc2_input_coord, + mc2_weight_coord, + mc2_output_1_coord, + mc2_output_2_coord), + }; - ParallelTensorSpaceCoordinate pt1 = mk_pt_coord(0_n, 0_n, 0_n, 0_n); - ParallelTensorSpaceCoordinate pt2 = mk_pt_coord(0_n, 1_n, 0_n, 0_n); + nlohmann::json result_json = result; + nlohmann::json correct_json = correct; - bidict src_binding{ - {pt1, mc1}, - {pt2, mc2}, - }; - bidict dst_binding{ - {pt1, mc3}, - {pt2, mc4}, - }; + CHECK(result.size() == correct.size()); + CHECK(result_json == correct_json); + CHECK(result == correct); + } + + SUBCASE("copy operator") { + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); + MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); + MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + + ParallelTensorSpaceCoordinate pt1 = mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate pt2 = mk_pt_coord(0_n, 1_n, 0_n, 0_n); + + OneToMany src_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + }; - DynamicNodeInvocation input = DynamicNodeInvocation{ - /*inputs=*/{ + OneToMany dst_binding{ + {pt1, {mc3}}, + {pt2, {mc4}}, + }; + + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ + { + mk_slot(TensorSlotName::INPUT), + mk_value(0, TensorSlotName::OUTPUT, src_binding, std::nullopt), + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/std::nullopt, + /*device_coord=*/std::nullopt, + /*mapping=*/std::nullopt, + /*op_attrs=*/TrainingOperationAttrs{CopyAttrs{}}, + /*layer_guid=*/dynamic_layer_guid_t{dynamic_copy_layer_guid_t{}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + mk_slot(TensorSlotName::OUTPUT), + mk_value(20, TensorSlotName::OUTPUT, dst_binding, std::nullopt), + }, + }, + }; + + std::unordered_set result = + generate_shard_expansion_for_invocation(input); + + auto mk_invocation_shard = + [&](MachineSpaceCoordinate const &device_coord, + ParallelTensorSpaceCoordinate const &tensor_shard_coord) + -> DynamicNodeInvocationShardingInfo { + + return DynamicNodeInvocationShardingInfo{ + /*device_coord=*/nonempty_set{device_coord}, + /*value_sharding=*/std::map{ { - mk_slot(TensorSlotName::INPUT), - mk_value(0, TensorSlotName::OUTPUT, src_binding, std::nullopt), + mk_slot(TensorSlotName::INPUT), + DynamicValueAttrsShardingInfo{ + tensor_shard_coord, + one_to_many_filter_keys(src_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == tensor_shard_coord; + }), + }, }, - }, - /*node_attrs=*/ - DynamicNodeAttrs{ - /*task_type=*/std::nullopt, - /*device_coord=*/std::nullopt, - /*mapping=*/std::nullopt, - /*op_attrs=*/TrainingOperationAttrs{CopyAttrs{}}, - /*layer_guid=*/dynamic_layer_guid_t{dynamic_copy_layer_guid_t{}}, - /*per_device_op_state=*/std::nullopt, - }, - /*outputs=*/ - { { - mk_slot(TensorSlotName::OUTPUT), - mk_value(20, TensorSlotName::OUTPUT, dst_binding, std::nullopt), + mk_slot(TensorSlotName::OUTPUT), + DynamicValueAttrsShardingInfo{ + tensor_shard_coord, + one_to_many_filter_keys(dst_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == tensor_shard_coord; + }), + }, }, - }, - }; + }, + }; + }; - std::unordered_set result = - perform_shard_expansion_for_invocation(input); + std::unordered_set correct = { + mk_invocation_shard(mc1, pt1), + mk_invocation_shard(mc2, pt2), + }; - auto mk_invocation_shard = - [&](MachineSpaceCoordinate const &device_coord, - ParallelTensorSpaceCoordinate const &tensor_shard_coord) - -> DynamicNodeInvocation { - DynamicNodeInvocation result = input; - result.inputs = { - { - mk_slot(TensorSlotName::INPUT), - mk_value( - 0, TensorSlotName::OUTPUT, src_binding, tensor_shard_coord), - }, + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + + SUBCASE("replicate operator") { + MachineSpaceCoordinate mc1 = mk_machine_coord(0_n, 0_n); + MachineSpaceCoordinate mc2 = mk_machine_coord(1_n, 0_n); + MachineSpaceCoordinate mc3 = mk_machine_coord(2_n, 0_n); + MachineSpaceCoordinate mc4 = mk_machine_coord(3_n, 0_n); + + ParallelTensorSpaceCoordinate pt1 = mk_pt_coord(0_n, 0_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate pt2 = mk_pt_coord(0_n, 0_n, 0_n, 1_n); + ParallelTensorSpaceCoordinate pt3 = mk_pt_coord(0_n, 1_n, 0_n, 0_n); + ParallelTensorSpaceCoordinate pt4 = mk_pt_coord(0_n, 1_n, 0_n, 1_n); + + auto mk_shard_binding = [&](ParallelTensorSpaceCoordinate const &c1, + ParallelTensorSpaceCoordinate const &c2) + -> OperatorAtomicTaskShardBinding { + return OperatorAtomicTaskShardBinding{ + /*tensor_coords=*/{ + { + TensorSlotName::INPUT, + c1, + }, + { + TensorSlotName::OUTPUT, + c2, + }, + }, + }; }; - // See perform_shard_expansion_for_copy in shard_expansion.cc for explanation of the choice of device placement. - result.node_attrs.device_coord = device_coord; - result.outputs = { - { - mk_slot(TensorSlotName::OUTPUT), - mk_value( - 20, TensorSlotName::OUTPUT, dst_binding, tensor_shard_coord), + + MappedOperatorTaskGroup mapped_task_group = MappedOperatorTaskGroup{ + bidict{ + { + mc1, + mk_shard_binding(pt1, pt1), + }, + { + mc2, + mk_shard_binding(pt1, pt2), + }, + { + mc3, + mk_shard_binding(pt2, pt3), + }, + { + mc4, + mk_shard_binding(pt2, pt4), + }, }, }; - return result; - }; - std::unordered_set correct = { - mk_invocation_shard(mc1, pt1), - mk_invocation_shard(mc2, pt2), - }; + SUBCASE("fwd") { + OneToMany src_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + }; + + OneToMany dst_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + {pt3, {mc3}}, + {pt4, {mc4}}, + }; + + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(0, TensorSlotName::OUTPUT, src_binding, std::nullopt), + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::FWD, + /*device_coords=*/std::nullopt, + /*mapping=*/mapped_task_group, + /*op_attrs=*/TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, + }, + }, + /*layer_guid=*/dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(20, TensorSlotName::OUTPUT, dst_binding, std::nullopt), + }, + }, + }; + + std::unordered_set result = + generate_shard_expansion_for_invocation(input); + + + auto mk_output_binding = [&](MachineSpaceCoordinate const &mc) + -> std::pair + { + return { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/mc, + }, + DynamicValueAttrsShardingInfo{ + dst_binding.at_r(mc), + one_to_many_filter_keys(dst_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == dst_binding.at_r(mc); + }), + }, + }; + }; + + auto mk_invocation_shard = + [&](nonempty_set const &device_coords, + ParallelTensorSpaceCoordinate const &input_shard_coord, + std::unordered_set const &output_task_shards) + -> DynamicNodeInvocationShardingInfo { + + return DynamicNodeInvocationShardingInfo{ + /*device_coords=*/device_coords, + /*value_sharding=*/ + binary_merge_disjoint_maps( + std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_fwd(), + /*task_shard=*/std::nullopt, + }, + DynamicValueAttrsShardingInfo{ + input_shard_coord, + one_to_many_filter_keys( + src_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == input_shard_coord; + }), + }, + }, + }, + map_from_pairs(transform(output_task_shards, mk_output_binding))), + }; + }; + + std::unordered_set correct = { + mk_invocation_shard(nonempty_set{mc1, mc2}, pt1, {mc1, mc2}), + mk_invocation_shard(nonempty_set{mc3, mc4}, pt2, {mc3, mc4}), + }; + + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + + SUBCASE("bwd") { + OneToMany output_grad_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + {pt3, {mc3}}, + {pt4, {mc4}}, + }; + + OneToMany input_grad_binding{ + {pt1, {mc1}}, + {pt2, {mc2}}, + }; + + DynamicNodeInvocation input = DynamicNodeInvocation{ + /*inputs=*/{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(0, TensorSlotName::OUTPUT, output_grad_binding, std::nullopt), + }, + }, + /*node_attrs=*/ + DynamicNodeAttrs{ + /*task_type=*/DynamicTaskType::BWD, + /*device_coords=*/std::nullopt, + /*mapping=*/mapped_task_group, + /*op_attrs=*/TrainingOperationAttrs{ + PCGOperatorAttrs{ + ReplicateAttrs{ + /*replicate_degree=*/2_p, + }, + }, + }, + /*layer_guid=*/dynamic_layer_guid_t{parallel_layer_guid_t{Node{20}}}, + /*per_device_op_state=*/std::nullopt, + }, + /*outputs=*/ + { + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }, + mk_value(20, TensorSlotName::INPUT, input_grad_binding, std::nullopt), + }, + }, + }; + + std::unordered_set result = + generate_shard_expansion_for_invocation(input); + + auto mk_output_grad_binding = [&](MachineSpaceCoordinate const &mc) + -> std::pair + { + return { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::OUTPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/mc, + }, + DynamicValueAttrsShardingInfo{ + output_grad_binding.at_r(mc), + one_to_many_filter_keys(output_grad_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == output_grad_binding.at_r(mc); + }), + }, + }; + }; + + auto mk_invocation_shard = + [&](nonempty_set const &device_coords, + std::unordered_set const &output_grad_task_shards, + ParallelTensorSpaceCoordinate const &input_grad_shard_coord) + -> DynamicNodeInvocationShardingInfo { + + return DynamicNodeInvocationShardingInfo{ + /*device_coords=*/device_coords, + /*value_sharding=*/ + binary_merge_disjoint_maps( + std::map{ + { + DynamicTensorSlot{ + /*slot_name=*/TensorSlotName::INPUT, + /*slot_tensor_role=*/mk_dynamic_tensor_role_bwd(), + /*task_shard=*/std::nullopt, + }, + DynamicValueAttrsShardingInfo{ + input_grad_shard_coord, + one_to_many_filter_keys( + input_grad_binding, + [&](ParallelTensorSpaceCoordinate const &pt_coord) -> bool { + return pt_coord == input_grad_shard_coord; + }), + }, + }, + }, + map_from_pairs(transform(output_grad_task_shards, mk_output_grad_binding))), + }; + }; + + std::unordered_set correct = { + mk_invocation_shard(nonempty_set{mc1, mc2}, {mc1, mc2}, pt1), + mk_invocation_shard(nonempty_set{mc3, mc4}, {mc3, mc4}, pt2), + }; - CHECK(result.size() == correct.size()); - CHECK(result == correct); + CHECK(result.size() == correct.size()); + CHECK(result == correct); + } + } } } diff --git a/lib/utils/include/utils/archetypes/jsonable_ordered_value_type.h b/lib/utils/include/utils/archetypes/jsonable_ordered_value_type.h new file mode 100644 index 0000000000..ad43cd52b6 --- /dev/null +++ b/lib/utils/include/utils/archetypes/jsonable_ordered_value_type.h @@ -0,0 +1,91 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_JSONABLE_ORDERED_VALUE_TYPE_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ARCHETYPES_JSONABLE_ORDERED_VALUE_TYPE_H + +#include +#include +#include +#include +#include +#include + +namespace FlexFlow { + +template +struct jsonable_ordered_value_type { + jsonable_ordered_value_type() = delete; + + jsonable_ordered_value_type(jsonable_ordered_value_type const &) { + PANIC(); + } + jsonable_ordered_value_type &operator=(jsonable_ordered_value_type const &) { + PANIC(); + } + + jsonable_ordered_value_type(jsonable_ordered_value_type &&) { + PANIC(); + } + jsonable_ordered_value_type &operator=(jsonable_ordered_value_type &&) { + PANIC(); + } + + bool operator==(jsonable_ordered_value_type const &) const { + PANIC(); + } + + bool operator!=(jsonable_ordered_value_type const &) const { + PANIC(); + } + + bool operator<(jsonable_ordered_value_type const &) const { + PANIC(); + } + bool operator>(jsonable_ordered_value_type const &) const { + PANIC(); + } + bool operator<=(jsonable_ordered_value_type const &) const { + PANIC(); + } + bool operator>=(jsonable_ordered_value_type const &) const { + PANIC(); + } +}; + +template +std::string format_as(jsonable_ordered_value_type const &) { + PANIC(); +} + +template +std::ostream &operator<<(std::ostream &s, jsonable_ordered_value_type const &x) { + PANIC(); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::jsonable_ordered_value_type> { + static ::FlexFlow::jsonable_ordered_value_type from_json(json const &) { + PANIC(); + } + + static void to_json(json &, ::FlexFlow::jsonable_ordered_value_type const &) { + PANIC(); + } +}; + +} // namespace nlohmann + +namespace std { + +template +struct hash<::FlexFlow::jsonable_ordered_value_type> { + size_t operator()(::FlexFlow::jsonable_ordered_value_type const &) const { + PANIC(); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/filter_keys.h b/lib/utils/include/utils/bidict/algorithms/bidict_filter_keys.h similarity index 53% rename from lib/utils/include/utils/bidict/algorithms/filter_keys.h rename to lib/utils/include/utils/bidict/algorithms/bidict_filter_keys.h index 2734dfaeb5..4c2ceb840b 100644 --- a/lib/utils/include/utils/bidict/algorithms/filter_keys.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_filter_keys.h @@ -1,12 +1,12 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_KEYS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_KEYS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTER_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTER_KEYS_H #include "utils/bidict/bidict.h" namespace FlexFlow { template -bidict filter_keys(bidict const &m, F &&f) { +bidict bidict_filter_keys(bidict const &m, F &&f) { bidict result; for (auto const &kv : m) { if (f(kv.first)) { diff --git a/lib/utils/include/utils/bidict/algorithms/filter_values.h b/lib/utils/include/utils/bidict/algorithms/bidict_filter_values.h similarity index 53% rename from lib/utils/include/utils/bidict/algorithms/filter_values.h rename to lib/utils/include/utils/bidict/algorithms/bidict_filter_values.h index 5817578e79..cb968f2d02 100644 --- a/lib/utils/include/utils/bidict/algorithms/filter_values.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_filter_values.h @@ -1,12 +1,12 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_VALUES_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTER_VALUES_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTER_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTER_VALUES_H #include "utils/bidict/bidict.h" namespace FlexFlow { template -bidict filter_values(bidict const &m, F &&f) { +bidict bidict_filter_values(bidict const &m, F &&f) { bidict result; for (auto const &kv : m) { if (f(kv.second)) { diff --git a/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h b/lib/utils/include/utils/bidict/algorithms/bidict_filtrans_keys.h similarity index 64% rename from lib/utils/include/utils/bidict/algorithms/filtrans_keys.h rename to lib/utils/include/utils/bidict/algorithms/bidict_filtrans_keys.h index df6495b400..bd9018bd38 100644 --- a/lib/utils/include/utils/bidict/algorithms/filtrans_keys.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_filtrans_keys.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_KEYS_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_KEYS_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTRANS_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTRANS_KEYS_H #include "utils/bidict/bidict.h" @@ -9,7 +9,7 @@ template ::value_type> -bidict filtrans_keys(bidict const &m, F &&f) { +bidict bidict_filtrans_keys(bidict const &m, F &&f) { bidict result; for (auto const &[k, v] : m) { std::optional new_k = f(k); diff --git a/lib/utils/include/utils/bidict/algorithms/filtrans_values.h b/lib/utils/include/utils/bidict/algorithms/bidict_filtrans_values.h similarity index 63% rename from lib/utils/include/utils/bidict/algorithms/filtrans_values.h rename to lib/utils/include/utils/bidict/algorithms/bidict_filtrans_values.h index 11180938b8..592440f7f6 100644 --- a/lib/utils/include/utils/bidict/algorithms/filtrans_values.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_filtrans_values.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_VALUES_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_FILTRANS_VALUES_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTRANS_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_FILTRANS_VALUES_H #include "utils/bidict/bidict.h" @@ -9,7 +9,7 @@ template ::value_type> -bidict filtrans_values(bidict const &m, F &&f) { +bidict bidict_filtrans_values(bidict const &m, F &&f) { bidict result; for (auto const &[k, v] : m) { std::optional new_v = f(v); diff --git a/lib/utils/include/utils/bidict/algorithms/unordered_set_of.h b/lib/utils/include/utils/bidict/algorithms/bidict_unordered_set_of.h similarity index 51% rename from lib/utils/include/utils/bidict/algorithms/unordered_set_of.h rename to lib/utils/include/utils/bidict/algorithms/bidict_unordered_set_of.h index b3df2514cf..251573d441 100644 --- a/lib/utils/include/utils/bidict/algorithms/unordered_set_of.h +++ b/lib/utils/include/utils/bidict/algorithms/bidict_unordered_set_of.h @@ -1,5 +1,5 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNORDERED_SET_OF_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNORDERED_SET_OF_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_UNORDERED_SET_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BIDICT_UNORDERED_SET_OF_H #include "utils/bidict/bidict.h" #include "utils/hash/pair.h" @@ -7,7 +7,7 @@ namespace FlexFlow { template -std::unordered_set> unordered_set_of(bidict const &c) { +std::unordered_set> bidict_unordered_set_of(bidict const &c) { std::unordered_set> result; for (auto const &lr : c) { diff --git a/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h new file mode 100644 index 0000000000..5b0bb45910 --- /dev/null +++ b/lib/utils/include/utils/bidict/algorithms/binary_merge_disjoint_bidicts.h @@ -0,0 +1,37 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BINARY_MERGE_DISJOINT_BIDICTS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_BINARY_MERGE_DISJOINT_BIDICTS_H + +#include "utils/bidict/algorithms/left_entries.h" +#include "utils/bidict/algorithms/right_entries.h" +#include "utils/bidict/bidict.h" +#include "utils/containers/are_disjoint.h" +#include "utils/exception.h" + +namespace FlexFlow { + +template +bidict binary_merge_disjoint_bidicts(bidict const &lhs, + bidict const &rhs) { + if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); + } + if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { + throw mk_runtime_error( + fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); + } + + bidict result; + for (auto const &kv : lhs) { + result.equate_strict(kv.first, kv.second); + } + for (auto const &kv : rhs) { + result.equate_strict(kv.first, kv.second); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h index 97e7334c26..0c944bb9bd 100644 --- a/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h +++ b/lib/utils/include/utils/bidict/algorithms/merge_disjoint_bidicts.h @@ -1,35 +1,21 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_MERGE_DISJOINT_BIDICTS_H -#include "utils/bidict/algorithms/left_entries.h" -#include "utils/bidict/algorithms/right_entries.h" -#include "utils/bidict/bidict.h" -#include "utils/containers/are_disjoint.h" -#include "utils/exception.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/containers/foldl.h" namespace FlexFlow { -template -bidict merge_disjoint_bidicts(bidict const &lhs, - bidict const &rhs) { - if (!are_disjoint(left_entries(lhs), left_entries(rhs))) { - throw mk_runtime_error( - fmt::format("Left entries of {} and {} are non-disjoint", lhs, rhs)); - } - if (!are_disjoint(right_entries(lhs), right_entries(rhs))) { - throw mk_runtime_error( - fmt::format("Right entries of {} and {} are non-disjoint", lhs, rhs)); - } - - bidict result; - for (auto const &kv : lhs) { - result.equate(kv.first, kv.second); - } - for (auto const &kv : rhs) { - result.equate(kv.first, kv.second); - } - - return result; +template +bidict merge_disjoint_bidicts(C const &c) { + bidict empty = {}; + return foldl(c, + /*init=*/empty, + [](bidict const &lhs, bidict const &rhs) { + return binary_merge_disjoint_bidicts(lhs, rhs); + }); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/bidict/algorithms/transform_keys.h b/lib/utils/include/utils/bidict/algorithms/transform_keys.h index 8ecb10c401..1d82464d17 100644 --- a/lib/utils/include/utils/bidict/algorithms/transform_keys.h +++ b/lib/utils/include/utils/bidict/algorithms/transform_keys.h @@ -12,7 +12,7 @@ template transform_keys(bidict const &m, F &&f) { bidict result; for (auto const &kv : m) { - result.equate(f(kv.first), kv.second); + result.equate_strict(f(kv.first), kv.second); } return result; } diff --git a/lib/utils/include/utils/bidict/algorithms/transform_values.h b/lib/utils/include/utils/bidict/algorithms/transform_values.h index ef5b34ebe9..fc8655594e 100644 --- a/lib/utils/include/utils/bidict/algorithms/transform_values.h +++ b/lib/utils/include/utils/bidict/algorithms/transform_values.h @@ -12,7 +12,7 @@ template transform_values(bidict const &m, F &&f) { bidict result; for (auto const &kv : m) { - result.equate({kv.first, f(kv.second)}); + result.equate_strict({kv.first, f(kv.second)}); } return result; } diff --git a/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h b/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h index 2ceb527b96..63d25332d3 100644 --- a/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h +++ b/lib/utils/include/utils/bidict/algorithms/unstructured_relation_from_bidict.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNSTRUCTURED_RELATION_FROM_BIDICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_ALGORITHMS_UNSTRUCTURED_RELATION_FROM_BIDICT_H -#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" #include "utils/bidict/bidict.h" namespace FlexFlow { @@ -9,7 +9,7 @@ namespace FlexFlow { template std::unordered_set> unstructured_relation_from_bidict(bidict const &b) { - return unordered_set_of(b); + return bidict_unordered_set_of(b); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/bidict/bidict.h b/lib/utils/include/utils/bidict/bidict.h index 5dbd1c603d..7fcc59f116 100644 --- a/lib/utils/include/utils/bidict/bidict.h +++ b/lib/utils/include/utils/bidict/bidict.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_BIDICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_BIDICT_BIDICT_H -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_from_keys_and_values.h" #include "utils/fmt/unordered_map.h" #include "utils/hash/unordered_map.h" @@ -13,6 +13,10 @@ #include #include #include +#include "utils/containers/require_same.h" +#include "utils/containers/values.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/containers/contains_key.h" namespace FlexFlow { @@ -65,11 +69,15 @@ struct bidict { void equate(L const &l, R const &r) { fwd_map.insert({l, r}); bwd_map.insert({r, l}); + + this->check_invariants(); } void equate(std::pair const &lr) { fwd_map.insert(lr); bwd_map.insert({lr.second, lr.first}); + + this->check_invariants(); } void equate_strict(L const &l, R const &r) { @@ -87,31 +95,35 @@ struct bidict { } bool operator==(bidict const &other) const { - bool result = this->fwd_map == other.fwd_map; - assert(result == (this->bwd_map == other.bwd_map)); - return result; + return require_same( + (this->fwd_map == other.fwd_map), + (this->bwd_map == other.bwd_map) + ); } bool operator!=(bidict const &other) const { - bool result = this->fwd_map != other.fwd_map; - assert(result == (this->bwd_map != other.bwd_map)); - return result; + return require_same( + (this->fwd_map != other.fwd_map), + (this->bwd_map != other.bwd_map) + ); } R const &at_l(L const &l) const { + ASSERT(contains_key(this->fwd_map, l)); return fwd_map.at(l); } L const &at_r(R const &r) const { + ASSERT(contains_key(this->bwd_map, r)); return bwd_map.at(r); } std::unordered_set left_values() const { - return keys(this->fwd_map); + return unordered_keys(this->fwd_map); } std::unordered_set right_values() const { - return keys(this->bwd_map); + return unordered_keys(this->bwd_map); } std::size_t size() const { @@ -213,11 +225,34 @@ struct bidict { return this->fwd_map; } + std::unordered_map const &l_to_r() const { + return this->fwd_map; + } + + std::unordered_map const &r_to_l() const { + return this->bwd_map; + } + bidict(std::unordered_map const &fwd_map, std::unordered_map const &bwd_map) : fwd_map(fwd_map), bwd_map(bwd_map) {} private: + void check_invariants() const { + std::unordered_set fwd_l_vals = unordered_keys(this->fwd_map); + std::unordered_set bwd_l_vals = unordered_set_of(values(this->bwd_map)); + + std::unordered_set bwd_r_vals = unordered_keys(this->bwd_map); + std::unordered_set fwd_r_vals = unordered_set_of(values(this->fwd_map)); + + ASSERT(fwd_l_vals == bwd_l_vals); + ASSERT(fwd_r_vals == bwd_r_vals); + + for (L const &l : fwd_l_vals) { + ASSERT(bwd_map.at(fwd_map.at(l)) == l); + } + } + friend struct bidict; std::unordered_map fwd_map; diff --git a/lib/utils/include/utils/containers/all_of.h b/lib/utils/include/utils/containers/all_of.h index ef5aac1c41..15ed234511 100644 --- a/lib/utils/include/utils/containers/all_of.h +++ b/lib/utils/include/utils/containers/all_of.h @@ -8,7 +8,7 @@ namespace FlexFlow { template -bool all_of(C const &c, F &&f) { +[[nodiscard]] bool all_of(C const &c, F &&f) { for (auto const &v : c) { if (!f(v)) { return false; @@ -18,7 +18,7 @@ bool all_of(C const &c, F &&f) { } template -bool all_of(std::unordered_map const &m, F &&f) { +[[nodiscard]] bool all_of(std::unordered_map const &m, F &&f) { for (auto const &[k, v] : m) { if (!f(k, v)) { return false; @@ -29,7 +29,7 @@ bool all_of(std::unordered_map const &m, F &&f) { } template -bool all_of(std::map const &m, F &&f) { +[[nodiscard]] bool all_of(std::map const &m, F &&f) { for (auto const &[k, v] : m) { if (!f(k, v)) { return false; @@ -39,7 +39,7 @@ bool all_of(std::map const &m, F &&f) { return true; } -bool all_of(std::vector const &); +[[nodiscard]] bool all_of(std::vector const &); } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h index 06a42327e1..85ff3d75fa 100644 --- a/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h +++ b/lib/utils/include/utils/containers/binary_merge_disjoint_maps.h @@ -3,18 +3,20 @@ #include "utils/containers/binary_merge_maps_with.h" #include +#include "utils/containers/keys.h" +#include "utils/containers/intersection.h" namespace FlexFlow { template -std::unordered_map - binary_merge_disjoint_maps(std::unordered_map const &lhs, - std::unordered_map const &rhs) { +std::map + binary_merge_disjoint_maps(std::map const &lhs, + std::map const &rhs) { - std::unordered_set lhs_keys = keys(lhs); - std::unordered_set rhs_keys = keys(rhs); + std::set lhs_keys = keys(lhs); + std::set rhs_keys = keys(rhs); - std::unordered_set shared_keys = intersection(lhs_keys, rhs_keys); + std::set shared_keys = intersection(lhs_keys, rhs_keys); ASSERT(shared_keys.empty()); return binary_merge_maps_with( diff --git a/lib/utils/include/utils/containers/binary_merge_disjoint_unordered_maps.h b/lib/utils/include/utils/containers/binary_merge_disjoint_unordered_maps.h new file mode 100644 index 0000000000..536d402be0 --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_disjoint_unordered_maps.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_DISJOINT_UNORDERED_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_DISJOINT_UNORDERED_MAPS_H + +#include +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include "utils/containers/unordered_keys.h" +#include "utils/containers/intersection.h" + +namespace FlexFlow { + +template +std::unordered_map + binary_merge_disjoint_unordered_maps(std::unordered_map const &lhs, + std::unordered_map const &rhs) { + + std::unordered_set lhs_keys = unordered_keys(lhs); + std::unordered_set rhs_keys = unordered_keys(rhs); + + std::unordered_set shared_keys = intersection(lhs_keys, rhs_keys); + ASSERT(shared_keys.empty()); + + return binary_merge_unordered_maps_with( + lhs, rhs, [](V const &, V const &) -> V { PANIC(); }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with.h b/lib/utils/include/utils/containers/binary_merge_maps_with.h index a7c196d061..3c3b556830 100644 --- a/lib/utils/include/utils/containers/binary_merge_maps_with.h +++ b/lib/utils/include/utils/containers/binary_merge_maps_with.h @@ -12,22 +12,22 @@ namespace FlexFlow { template -std::unordered_map - binary_merge_maps_with(std::unordered_map const &lhs, - std::unordered_map const &rhs, +std::map + binary_merge_maps_with(std::map const &lhs, + std::map const &rhs, F &&f) { - std::unordered_set l_keys = keys(lhs); - std::unordered_set r_keys = keys(rhs); + std::set l_keys = keys(lhs); + std::set r_keys = keys(rhs); - std::unordered_set l_only_keys = set_minus(l_keys, r_keys); - std::unordered_set r_only_keys = set_minus(r_keys, l_keys); - std::unordered_set both_keys = intersection(r_keys, l_keys); + std::set l_only_keys = set_minus(l_keys, r_keys); + std::set r_only_keys = set_minus(r_keys, l_keys); + std::set both_keys = intersection(r_keys, l_keys); - std::unordered_map l_only = restrict_keys(lhs, l_only_keys); - std::unordered_map r_only = restrict_keys(rhs, r_only_keys); + std::map l_only = restrict_keys(lhs, l_only_keys); + std::map r_only = restrict_keys(rhs, r_only_keys); - std::unordered_map merged = generate_map( + std::map merged = generate_map( both_keys, [&](K const &k) { return f(lhs.at(k), rhs.at(k)); }); return merge_maps_with_right_dominating(std::vector{ diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h b/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h index f6e23af11c..25be62d9c3 100644 --- a/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h +++ b/lib/utils/include/utils/containers/binary_merge_maps_with_left_dominating.h @@ -6,9 +6,9 @@ namespace FlexFlow { template -std::unordered_map binary_merge_maps_with_left_dominating( - std::unordered_map const &lhs, std::unordered_map const &rhs) { - std::unordered_map result; +std::map binary_merge_maps_with_left_dominating( + std::map const &lhs, std::map const &rhs) { + std::map result; merge_in_map(rhs, result); merge_in_map(lhs, result); return result; diff --git a/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h b/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h index e5e29dfcb9..e4bfdd6d29 100644 --- a/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h +++ b/lib/utils/include/utils/containers/binary_merge_maps_with_right_dominating.h @@ -6,9 +6,9 @@ namespace FlexFlow { template -std::unordered_map binary_merge_maps_with_right_dominating( - std::unordered_map const &lhs, std::unordered_map const &rhs) { - std::unordered_map result; +std::map binary_merge_maps_with_right_dominating( + std::map const &lhs, std::map const &rhs) { + std::map result; merge_in_map(lhs, result); merge_in_map(rhs, result); return result; diff --git a/lib/utils/include/utils/containers/binary_merge_unordered_maps_with.h b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with.h new file mode 100644 index 0000000000..2f8be45802 --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with.h @@ -0,0 +1,42 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_H + +#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/intersection.h" +#include "utils/containers/unordered_keys.h" +#include "utils/containers/merge_unordered_maps_with_right_dominating.h" +#include "utils/containers/restrict_keys.h" +#include "utils/containers/set_minus.h" +#include + +namespace FlexFlow { + +template +std::unordered_map + binary_merge_unordered_maps_with(std::unordered_map const &lhs, + std::unordered_map const &rhs, + F &&f) { + + std::unordered_set l_keys = unordered_keys(lhs); + std::unordered_set r_keys = unordered_keys(rhs); + + std::unordered_set l_only_keys = set_minus(l_keys, r_keys); + std::unordered_set r_only_keys = set_minus(r_keys, l_keys); + std::unordered_set both_keys = intersection(r_keys, l_keys); + + std::unordered_map l_only = restrict_keys(lhs, l_only_keys); + std::unordered_map r_only = restrict_keys(rhs, r_only_keys); + + std::unordered_map merged = generate_unordered_map( + both_keys, [&](K const &k) { return f(lhs.at(k), rhs.at(k)); }); + + return merge_unordered_maps_with_right_dominating(std::vector{ + l_only, + r_only, + merged, + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_left_dominating.h b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_left_dominating.h new file mode 100644 index 0000000000..0d71a1b7cc --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_left_dominating.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_LEFT_DOMINATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_LEFT_DOMINATING_H + +#include "utils/containers/merge_in_unordered_map.h" + +namespace FlexFlow { + +template +std::unordered_map binary_merge_unordered_maps_with_left_dominating( + std::unordered_map const &lhs, std::unordered_map const &rhs) { + std::unordered_map result; + merge_in_unordered_map(rhs, result); + merge_in_unordered_map(lhs, result); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_right_dominating.h b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_right_dominating.h new file mode 100644 index 0000000000..6a3581a635 --- /dev/null +++ b/lib/utils/include/utils/containers/binary_merge_unordered_maps_with_right_dominating.h @@ -0,0 +1,19 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_RIGHT_DOMINATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_BINARY_MERGE_UNORDERED_MAPS_WITH_RIGHT_DOMINATING_H + +#include "utils/containers/merge_in_unordered_map.h" + +namespace FlexFlow { + +template +std::unordered_map binary_merge_unordered_maps_with_right_dominating( + std::unordered_map const &lhs, std::unordered_map const &rhs) { + std::unordered_map result; + merge_in_unordered_map(lhs, result); + merge_in_unordered_map(rhs, result); + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/filter.h b/lib/utils/include/utils/containers/filter.h index 07f25dc348..85a413c2c7 100644 --- a/lib/utils/include/utils/containers/filter.h +++ b/lib/utils/include/utils/containers/filter.h @@ -44,6 +44,13 @@ std::map filter(std::map const &m, F const &f) { return result; } +template +std::multiset filter(std::multiset const &m, F const &f) { + std::multiset result; + std::copy_if(m.cbegin(), m.cend(), std::inserter(result, result.begin()), f); + return result; +} + template std::unordered_multiset filter(std::unordered_multiset const &m, F const &f) { diff --git a/lib/utils/include/utils/containers/flatmap.h b/lib/utils/include/utils/containers/flatmap.h index 70de6b5020..e0440cc791 100644 --- a/lib/utils/include/utils/containers/flatmap.h +++ b/lib/utils/include/utils/containers/flatmap.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FLATMAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_FLATMAP_H -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/extend.h" #include "utils/containers/get_element_type.h" #include #include #include +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -77,7 +77,7 @@ std::unordered_map flatmap(std::unordered_map const &m, std::unordered_map result; for (auto const &[k, v] : m) { - result = binary_merge_disjoint_maps(result, f(k, v)); + result = binary_merge_disjoint_unordered_maps(result, f(k, v)); } return result; diff --git a/lib/utils/include/utils/containers/generate_map.h b/lib/utils/include/utils/containers/generate_map.h index 53b2a590c5..08bfc86350 100644 --- a/lib/utils/include/utils/containers/generate_map.h +++ b/lib/utils/include/utils/containers/generate_map.h @@ -5,7 +5,7 @@ #include "utils/containers/vector_of.h" #include "utils/containers/vector_transform.h" #include "utils/type_traits_core.h" -#include +#include namespace FlexFlow { @@ -13,8 +13,8 @@ template , typename V = std::invoke_result_t> -std::unordered_map generate_map(C const &c, F const &f) { - static_assert(is_hashable_v, "Key type should be hashable (but is not)"); +std::map generate_map(C const &c, F &&f) { + static_assert(is_lt_comparable_v, "Key type should be ordered (but is not)"); auto transformed = vector_transform(vector_of(c), [&](K const &k) -> std::pair { diff --git a/lib/utils/include/utils/containers/generate_unordered_map.h b/lib/utils/include/utils/containers/generate_unordered_map.h new file mode 100644 index 0000000000..d57632b7ae --- /dev/null +++ b/lib/utils/include/utils/containers/generate_unordered_map.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_UNORDERED_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GENERATE_UNORDERED_MAP_H + +#include "utils/containers/get_element_type.h" +#include "utils/containers/vector_of.h" +#include "utils/containers/vector_transform.h" +#include "utils/type_traits_core.h" +#include + +namespace FlexFlow { + +template , + typename V = std::invoke_result_t> +std::unordered_map generate_unordered_map(C const &c, F &&f) { + static_assert(is_hashable_v, "Key type should be hashable (but is not)"); + + auto transformed = + vector_transform(vector_of(c), [&](K const &k) -> std::pair { + return {k, f(k)}; + }); + return {transformed.cbegin(), transformed.cend()}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/get_all_assignments.h b/lib/utils/include/utils/containers/get_all_assignments.h index 9981948f47..8f77ffbc24 100644 --- a/lib/utils/include/utils/containers/get_all_assignments.h +++ b/lib/utils/include/utils/containers/get_all_assignments.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ALL_ASSIGNMENTS_H #include "utils/containers/cartesian_product.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_map_from_pairs.h" #include "utils/containers/unordered_set_of.h" @@ -26,7 +26,7 @@ std::unordered_set> get_all_assignments( return {{}}; } - std::vector ordered_keys = vector_of(keys(options_per_key)); + std::vector ordered_keys = vector_of(unordered_keys(options_per_key)); std::vector> ordered_value_option_sets = transform( ordered_keys, [&](K const &k) { return options_per_key.at(k); }); diff --git a/lib/utils/include/utils/containers/get_only.h b/lib/utils/include/utils/containers/get_only.h index ed44b26c36..d5e2faf67a 100644 --- a/lib/utils/include/utils/containers/get_only.h +++ b/lib/utils/include/utils/containers/get_only.h @@ -2,17 +2,15 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_GET_ONLY_H #include "utils/containers/maybe_get_only.h" -#include "utils/exception.h" +#include #include "utils/optional.h" namespace FlexFlow { template typename C::value_type get_only(C const &c) { - return unwrap(maybe_get_only(c), [&] { - throw mk_runtime_error(fmt::format( - "Encountered container with size {} in get_only", c.size())); - }); + ASSERT(c.size() == 1); + return maybe_get_only(c).value(); } template diff --git a/lib/utils/include/utils/containers/index.dox b/lib/utils/include/utils/containers/index.dox index 9b3865dd78..2ffda1cfdf 100644 --- a/lib/utils/include/utils/containers/index.dox +++ b/lib/utils/include/utils/containers/index.dox @@ -9,7 +9,7 @@ Some of the most commonly-used functions are listed below, but you should ideall - \ref containers/transform.h - \ref containers/filter.h - \ref containers/contains.h -- \ref containers/generate_map.h +- \ref containers/generate_unordered_map.h - \ref containers/get_only.h - \ref containers/slice.h - \ref containers/merge_disjoint_maps.h diff --git a/lib/utils/include/utils/containers/is_submapeq_of.h b/lib/utils/include/utils/containers/is_submapeq_of.h index 03cb5ccd78..e50e85e745 100644 --- a/lib/utils/include/utils/containers/is_submapeq_of.h +++ b/lib/utils/include/utils/containers/is_submapeq_of.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H #define _FLEXFLOW_UTILS_INCLUDE_UTILS_CONTAINERS_IS_SUBMAP_H -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/restrict_keys.h" #include @@ -10,7 +10,7 @@ namespace FlexFlow { template bool is_submapeq_of(std::unordered_map const &sub, std::unordered_map const &m) { - return restrict_keys(m, keys(sub)) == sub; + return restrict_keys(m, unordered_keys(sub)) == sub; } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/items.h b/lib/utils/include/utils/containers/items.h index 8e3ba95d6c..13e745b17e 100644 --- a/lib/utils/include/utils/containers/items.h +++ b/lib/utils/include/utils/containers/items.h @@ -1,12 +1,12 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ITEMS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ITEMS_H -#include +#include namespace FlexFlow { template -std::unordered_set> +std::set> items(C const &c) { return {c.begin(), c.end()}; } diff --git a/lib/utils/include/utils/containers/keys.h b/lib/utils/include/utils/containers/keys.h index e14612541e..bd080b7087 100644 --- a/lib/utils/include/utils/containers/keys.h +++ b/lib/utils/include/utils/containers/keys.h @@ -3,13 +3,13 @@ #include #include -#include +#include namespace FlexFlow { template -std::unordered_set keys(std::unordered_map const &c) { - std::unordered_set result; +std::set keys(std::unordered_map const &c) { + std::set result; for (auto const &kv : c) { result.insert(kv.first); } @@ -17,8 +17,8 @@ std::unordered_set keys(std::unordered_map const &c) { } template -std::unordered_set keys(std::map const &c) { - std::unordered_set result; +std::set keys(std::map const &c) { + std::set result; for (auto const &kv : c) { result.insert(kv.first); } diff --git a/lib/utils/include/utils/containers/lookup_in_map.h b/lib/utils/include/utils/containers/lookup_in_map.h index 946fc589db..339b13f042 100644 --- a/lib/utils/include/utils/containers/lookup_in_map.h +++ b/lib/utils/include/utils/containers/lookup_in_map.h @@ -1,24 +1,22 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_LOOKUP_IN_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_LOOKUP_IN_MAP_H -#include "utils/containers/contains.h" -#include "utils/containers/keys.h" -#include "utils/exception.h" #include "utils/fmt/unordered_map.h" +#include "utils/containers/contains_key.h" #include #include #include +#include namespace FlexFlow { template -std::function lookup_in_map(std::unordered_map const &map) { - return [map](K const &key) -> V { - if (!contains(keys(map), key)) { - throw mk_runtime_error(fmt::format( - "Key {} is not present in the underlying map {}", key, map)); +std::function lookup_in_map(std::unordered_map const &m) { + return [m](K const &key) -> V { + if (!contains_key(m, key)) { + PANIC("Key {} is not present in the underlying map {}", key, m); } - return map.at(key); + return m.at(key); }; } diff --git a/lib/utils/include/utils/containers/map_from_pairs.h b/lib/utils/include/utils/containers/map_from_pairs.h index 7c470d4d3e..f5f2dad415 100644 --- a/lib/utils/include/utils/containers/map_from_pairs.h +++ b/lib/utils/include/utils/containers/map_from_pairs.h @@ -1,18 +1,15 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_PAIRS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_PAIRS_H -#include -#include +#include namespace FlexFlow { -template -std::unordered_map - map_from_pairs(std::unordered_set> const &pairs) { - - std::unordered_map result(pairs.cbegin(), pairs.cend()); - - return result; +template +std::map map_from_pairs(C const &c) { + return std::map(c.cbegin(), c.cend()); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/containers/map_from_unordered.h b/lib/utils/include/utils/containers/map_from_unordered.h new file mode 100644 index 0000000000..1451cd1aa8 --- /dev/null +++ b/lib/utils/include/utils/containers/map_from_unordered.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_UNORDERED_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_FROM_UNORDERED_H + +#include +#include + +namespace FlexFlow { + +template +std::map map_from_unordered(std::unordered_map const &u) { + std::map result{u.cbegin(), u.cend()}; + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/map_keys.h b/lib/utils/include/utils/containers/map_keys.h index 5cd44d8a5d..ff41248a30 100644 --- a/lib/utils/include/utils/containers/map_keys.h +++ b/lib/utils/include/utils/containers/map_keys.h @@ -2,11 +2,13 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_H #include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/transform.h" #include "utils/containers/unordered_multiset_of.h" #include "utils/exception.h" #include #include +#include namespace FlexFlow { @@ -19,14 +21,35 @@ template > std::unordered_map map_keys(std::unordered_map const &m, - F const &f) { + F &&f) { std::unordered_map result; for (auto const &kv : m) { result.insert({f(kv.first), kv.second}); } - ASSERT(keys(m).size() == keys(result).size(), + ASSERT(m.size() == result.size(), + "keys passed to map_keys must be transformed into distinct keys"); + + return result; +} + +/** + * @brief Applies the given function to all the keys within the given map and + * returns the updated map. + */ +template > +std::map map_keys(std::map const &m, F &&f) { + + std::map result; + for (auto const &kv : m) { + result.insert({f(kv.first), kv.second}); + } + + ASSERT(m.size() == result.size(), "keys passed to map_keys must be transformed into distinct keys"); return result; diff --git a/lib/utils/include/utils/containers/map_keys2.h b/lib/utils/include/utils/containers/map_keys2.h index fd848f18d8..da68fe05b4 100644 --- a/lib/utils/include/utils/containers/map_keys2.h +++ b/lib/utils/include/utils/containers/map_keys2.h @@ -19,7 +19,7 @@ std::unordered_map map_keys2(std::unordered_map const &m, result.insert({f(kv.first, kv.second), kv.second}); } - ASSERT(keys(m).size() == keys(result).size(), + ASSERT(m.size() == result.size(), "keys passed to map_keys must be transformed into distinct keys"); return result; diff --git a/lib/utils/include/utils/containers/map_keys_and_values.h b/lib/utils/include/utils/containers/map_keys_and_values.h index 651ffb2aeb..1873421e17 100644 --- a/lib/utils/include/utils/containers/map_keys_and_values.h +++ b/lib/utils/include/utils/containers/map_keys_and_values.h @@ -1,9 +1,9 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_AND_VALUES_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MAP_KEYS_AND_VALUES_H -#include "utils/containers/keys.h" #include #include +#include namespace FlexFlow { @@ -21,7 +21,27 @@ std::unordered_map map_keys_and_values( result.insert({fk(kv.first), fv(kv.second)}); } - ASSERT(keys(m).size() == keys(result).size(), + ASSERT(m.size() == result.size(), + "keys passed to map_keys must be transformed into distinct keys"); + + return result; +} + +template , + typename V2 = std::invoke_result_t> +std::map map_keys_and_values( + std::map const &m, FK const &fk, FV const &fv) { + + std::map result; + for (auto const &kv : m) { + result.insert({fk(kv.first), fv(kv.second)}); + } + + ASSERT(m.size() == result.size(), "keys passed to map_keys must be transformed into distinct keys"); return result; diff --git a/lib/utils/include/utils/containers/map_values.h b/lib/utils/include/utils/containers/map_values.h index bf377b2c93..575fff977e 100644 --- a/lib/utils/include/utils/containers/map_values.h +++ b/lib/utils/include/utils/containers/map_values.h @@ -3,6 +3,7 @@ #include #include +#include namespace FlexFlow { @@ -18,6 +19,18 @@ std::unordered_map map_values(std::unordered_map const &m, F &&f) { return result; } +template > +std::map map_values(std::map const &m, F &&f) { + std::map result; + for (std::pair const &kv : m) { + result.insert(std::pair{kv.first, f(kv.second)}); + } + return result; +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/map_values2.h b/lib/utils/include/utils/containers/map_values2.h index 752a8babd3..dd943b02bb 100644 --- a/lib/utils/include/utils/containers/map_values2.h +++ b/lib/utils/include/utils/containers/map_values2.h @@ -3,6 +3,7 @@ #include #include +#include namespace FlexFlow { @@ -19,6 +20,20 @@ std::unordered_map map_values2(std::unordered_map const &m, return result; } +template > +std::map map_values2(std::map const &m, + F &&f) { + std::map result; + for (std::pair const &kv : m) { + result.insert(std::pair{kv.first, f(kv.first, kv.second)}); + } + return result; +} + + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/merge_disjoint_maps.h b/lib/utils/include/utils/containers/merge_disjoint_maps.h index eccb06180a..b541fdbd53 100644 --- a/lib/utils/include/utils/containers/merge_disjoint_maps.h +++ b/lib/utils/include/utils/containers/merge_disjoint_maps.h @@ -9,12 +9,12 @@ namespace FlexFlow { template -std::unordered_map merge_disjoint_maps(C const &c) { - std::unordered_map empty = {}; +std::map merge_disjoint_maps(C const &c) { + std::map empty = {}; return foldl(c, /*init=*/empty, - [](std::unordered_map const &lhs, - std::unordered_map const &rhs) { + [](std::map const &lhs, + std::map const &rhs) { return binary_merge_disjoint_maps(lhs, rhs); }); } diff --git a/lib/utils/include/utils/containers/merge_disjoint_unordered_maps.h b/lib/utils/include/utils/containers/merge_disjoint_unordered_maps.h new file mode 100644 index 0000000000..1bd7fb019b --- /dev/null +++ b/lib/utils/include/utils/containers/merge_disjoint_unordered_maps.h @@ -0,0 +1,24 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_DISJOINT_UNORDERED_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_DISJOINT_UNORDERED_MAPS_H + +#include "utils/containers/foldl.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" + +namespace FlexFlow { + +template +std::unordered_map merge_disjoint_unordered_maps(C const &c) { + std::unordered_map empty = {}; + return foldl(c, + /*init=*/empty, + [](std::unordered_map const &lhs, + std::unordered_map const &rhs) { + return binary_merge_disjoint_unordered_maps(lhs, rhs); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_in_map.h b/lib/utils/include/utils/containers/merge_in_map.h index edae4b8a6a..e41c1a6826 100644 --- a/lib/utils/include/utils/containers/merge_in_map.h +++ b/lib/utils/include/utils/containers/merge_in_map.h @@ -1,13 +1,13 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_MAP_H -#include +#include namespace FlexFlow { template -void merge_in_map(std::unordered_map const &m, - std::unordered_map &result) { +void merge_in_map(std::map const &m, + std::map &result) { for (auto const &[k, v] : m) { auto it = result.find(k); if (it != result.end()) { diff --git a/lib/utils/include/utils/containers/merge_in_unordered_map.h b/lib/utils/include/utils/containers/merge_in_unordered_map.h new file mode 100644 index 0000000000..7c2b31b8fc --- /dev/null +++ b/lib/utils/include/utils/containers/merge_in_unordered_map.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_UNORDERED_MAPS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_IN_UNORDERED_MAPS_H + +#include + +namespace FlexFlow { + +template +void merge_in_unordered_map(std::unordered_map const &m, + std::unordered_map &result) { + for (auto const &[k, v] : m) { + auto it = result.find(k); + if (it != result.end()) { + it->second = v; + } else { + result.insert({k, v}); + } + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_maps_with.h b/lib/utils/include/utils/containers/merge_maps_with.h index 2f5a09e26e..eef6c9af67 100644 --- a/lib/utils/include/utils/containers/merge_maps_with.h +++ b/lib/utils/include/utils/containers/merge_maps_with.h @@ -9,13 +9,13 @@ namespace FlexFlow { template -std::unordered_map - merge_maps_with(std::vector> const &to_merge, +std::map + merge_maps_with(std::vector> const &to_merge, F &&f) { return foldl(to_merge, - std::unordered_map{}, - [&](std::unordered_map const &accum, - std::unordered_map const &m) { + std::map{}, + [&](std::map const &accum, + std::map const &m) { return binary_merge_maps_with(accum, m, f); }); } diff --git a/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h b/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h index 1d4f8536d8..6271cef2c0 100644 --- a/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h +++ b/lib/utils/include/utils/containers/merge_maps_with_right_dominating.h @@ -8,10 +8,10 @@ namespace FlexFlow { template -std::unordered_map merge_maps_with_right_dominating(C const &c) { - std::unordered_map result; +std::map merge_maps_with_right_dominating(C const &c) { + std::map result; - for (std::unordered_map const &m : c) { + for (std::map const &m : c) { merge_in_map(m, result); } diff --git a/lib/utils/include/utils/containers/merge_unordered_maps_with.h b/lib/utils/include/utils/containers/merge_unordered_maps_with.h new file mode 100644 index 0000000000..fee7fa2fa4 --- /dev/null +++ b/lib/utils/include/utils/containers/merge_unordered_maps_with.h @@ -0,0 +1,25 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_UNORDERED_MAPS_WITH_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_UNORDERED_MAPS_WITH_H + +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include "utils/containers/foldl.h" +#include +#include + +namespace FlexFlow { + +template +std::unordered_map + merge_unordered_maps_with(std::vector> const &to_merge, + F &&f) { + return foldl(to_merge, + std::unordered_map{}, + [&](std::unordered_map const &accum, + std::unordered_map const &m) { + return binary_merge_unordered_maps_with(accum, m, f); + }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/merge_unordered_maps_with_right_dominating.h b/lib/utils/include/utils/containers/merge_unordered_maps_with_right_dominating.h new file mode 100644 index 0000000000..1323378019 --- /dev/null +++ b/lib/utils/include/utils/containers/merge_unordered_maps_with_right_dominating.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_UNORDERED_MAPS_WITH_RIGHT_DOMINATING_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_MERGE_UNORDERED_MAPS_WITH_RIGHT_DOMINATING_H + +#include "utils/containers/merge_in_unordered_map.h" + +namespace FlexFlow { + +template +std::unordered_map merge_unordered_maps_with_right_dominating(C const &c) { + std::unordered_map result; + + for (std::unordered_map const &m : c) { + merge_in_unordered_map(m, result); + } + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_all_of.h b/lib/utils/include/utils/containers/require_all_of.h new file mode 100644 index 0000000000..c085161659 --- /dev/null +++ b/lib/utils/include/utils/containers/require_all_of.h @@ -0,0 +1,33 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_OF_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_ALL_OF_H + +#include +#include +#include + +namespace FlexFlow { + +template +void require_all_of(C const &c, F &&f) { + for (auto const &v : c) { + f(v); + } +} + +template +void require_all_of(std::unordered_map const &m, F &&f) { + for (auto const &[k, v] : m) { + f(k, v); + } +} + +template +void require_all_of(std::map const &m, F &&f) { + for (auto const &[k, v] : m) { + f(k, v); + } +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/require_only_key.h b/lib/utils/include/utils/containers/require_only_key.h index c63ff4d440..ef142921ec 100644 --- a/lib/utils/include/utils/containers/require_only_key.h +++ b/lib/utils/include/utils/containers/require_only_key.h @@ -4,6 +4,7 @@ #include "utils/containers/contains_key.h" #include #include +#include namespace FlexFlow { @@ -15,6 +16,14 @@ V require_only_key(std::unordered_map const &m, K const &k) { return m.at(k); } +template +V require_only_key(std::map const &m, K const &k) { + ASSERT(m.size() == 1); + ASSERT(contains_key(m, k)); + + return m.at(k); +} + } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/require_same.h b/lib/utils/include/utils/containers/require_same.h index 2f3439db32..2f6251064c 100644 --- a/lib/utils/include/utils/containers/require_same.h +++ b/lib/utils/include/utils/containers/require_same.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_SAME_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_REQUIRE_SAME_H -#include "utils/exception.h" +#include #include namespace FlexFlow { diff --git a/lib/utils/include/utils/containers/restrict_keys.h b/lib/utils/include/utils/containers/restrict_keys.h index bedcc4ed8e..353b2ec237 100644 --- a/lib/utils/include/utils/containers/restrict_keys.h +++ b/lib/utils/include/utils/containers/restrict_keys.h @@ -4,6 +4,8 @@ #include "utils/containers/contains.h" #include #include +#include +#include namespace FlexFlow { @@ -19,6 +21,17 @@ std::unordered_map restrict_keys(std::unordered_map const &m, return result; } +template +std::map restrict_keys(std::map const &m, + std::set const &mask) { + std::map result; + for (auto const &kv : m) { + if (contains(mask, kv.first)) { + result.insert(kv); + } + } + return result; +} } // namespace FlexFlow #endif diff --git a/lib/utils/include/utils/containers/transform.h b/lib/utils/include/utils/containers/transform.h index 14ef782690..bb34b2b5a5 100644 --- a/lib/utils/include/utils/containers/transform.h +++ b/lib/utils/include/utils/containers/transform.h @@ -2,13 +2,15 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_H #include "utils/containers/vector_transform.h" -#include "utils/required_core.h" #include #include #include #include #include #include +#include +#include +#include namespace FlexFlow { @@ -17,12 +19,6 @@ std::vector transform(std::vector const &v, F const &f) { return vector_transform(v, f); } -template -auto transform(req const &c, F const &f) - -> decltype(transform(std::declval(), std::declval())) { - return transform(static_cast(c), f); -} - template > std::unordered_set transform(std::unordered_set const &v, F const &f) { std::unordered_set result; @@ -86,8 +82,8 @@ template ::first_type, typename V2 = typename std::invoke_result_t::second_type> -std::unordered_map transform(std::map const &m, F const &f) { - std::unordered_map result; +std::map transform(std::map const &m, F const &f) { + std::map result; for (auto const &[k, v] : m) { result.insert(f(k, v)); } diff --git a/lib/utils/include/utils/containers/transform_pairs.h b/lib/utils/include/utils/containers/transform_pairs.h new file mode 100644 index 0000000000..3e421ea445 --- /dev/null +++ b/lib/utils/include/utils/containers/transform_pairs.h @@ -0,0 +1,47 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_PAIRS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_TRANSFORM_PAIRS_H + +#include "utils/containers/transform.h" + +namespace FlexFlow { + +template > +std::vector transform_pairs(std::vector> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +template > +std::unordered_set + transform_pairs(std::unordered_set> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +template > +std::set transform_pairs(std::set> const &c, F &&f) { + auto ff = [&](std::pair const &p) -> Out { + return f(p.first, p.second); + }; + + return transform(c, ff); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_items.h b/lib/utils/include/utils/containers/unordered_items.h new file mode 100644 index 0000000000..1bd8da1498 --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_items.h @@ -0,0 +1,17 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_ITEMS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_ITEMS_H + +#include +#include "utils/hash/pair.h" + +namespace FlexFlow { + +template +std::unordered_set> + unordered_items(C const &c) { + return {c.begin(), c.end()}; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_keys.h b/lib/utils/include/utils/containers/unordered_keys.h new file mode 100644 index 0000000000..e4b74a6f5e --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_keys.h @@ -0,0 +1,30 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_KEYS_H + +#include +#include +#include + +namespace FlexFlow { + +template +std::unordered_set unordered_keys(std::unordered_map const &c) { + std::unordered_set result; + for (auto const &kv : c) { + result.insert(kv.first); + } + return result; +} + +template +std::unordered_set unordered_keys(std::map const &c) { + std::unordered_set result; + for (auto const &kv : c) { + result.insert(kv.first); + } + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/unordered_map_from_map.h b/lib/utils/include/utils/containers/unordered_map_from_map.h new file mode 100644 index 0000000000..4410e5a767 --- /dev/null +++ b/lib/utils/include/utils/containers/unordered_map_from_map.h @@ -0,0 +1,18 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_MAP_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_UNORDERED_MAP_FROM_MAP_H + +#include +#include + +namespace FlexFlow { + +template +std::unordered_map unordered_map_from_map(std::map const &u) { + std::unordered_map result{u.cbegin(), u.cend()}; + + return result; +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/containers/zip_values_strict.h b/lib/utils/include/utils/containers/zip_values_strict.h index 60a7985bc5..b891490e39 100644 --- a/lib/utils/include/utils/containers/zip_values_strict.h +++ b/lib/utils/include/utils/containers/zip_values_strict.h @@ -1,11 +1,14 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_H -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/require_same.h" #include #include +#include +#include "utils/containers/keys.h" +#include "utils/containers/generate_map.h" namespace FlexFlow { @@ -14,6 +17,21 @@ std::unordered_map> zip_values_strict(std::unordered_map const &m1, std::unordered_map const &m2) { + ASSERT(unordered_keys(m1) == unordered_keys(m2)); + + return generate_unordered_map(require_same(unordered_keys(m1), unordered_keys(m2)), [&](K const &k) { + return std::pair{ + m1.at(k), + m2.at(k), + }; + }); +} + +template +std::map> + zip_values_strict(std::map const &m1, + std::map const &m2) { + ASSERT(keys(m1) == keys(m2)); return generate_map(require_same(keys(m1), keys(m2)), [&](K const &k) { diff --git a/lib/utils/include/utils/containers/zip_values_strict_with.h b/lib/utils/include/utils/containers/zip_values_strict_with.h index 3b0530db8a..5fc4bb7f5b 100644 --- a/lib/utils/include/utils/containers/zip_values_strict_with.h +++ b/lib/utils/include/utils/containers/zip_values_strict_with.h @@ -1,8 +1,8 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_WITH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_CONTAINERS_ZIP_VALUES_STRICT_WITH_H -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/require_same.h" #include #include @@ -19,9 +19,9 @@ std::unordered_map std::unordered_map const &m2, F &&f) { - ASSERT(keys(m1) == keys(m2)); + ASSERT(unordered_keys(m1) == unordered_keys(m2)); - return generate_map(require_same(keys(m1), keys(m2)), + return generate_unordered_map(require_same(unordered_keys(m1), unordered_keys(m2)), [&](K const &k) -> Out { return f(m1.at(k), m2.at(k)); }); } diff --git a/lib/utils/include/utils/fmt/map.h b/lib/utils/include/utils/fmt/map.h index 9225040d4d..5b0d41a9cc 100644 --- a/lib/utils/include/utils/fmt/map.h +++ b/lib/utils/include/utils/fmt/map.h @@ -22,10 +22,8 @@ struct formatter< CHECK_FMTABLE(K); CHECK_FMTABLE(V); - std::vector> items = ::FlexFlow::sorted(m); - std::string result = ::FlexFlow::join_strings( - items.cbegin(), items.cend(), ", ", [](std::pair const &p) { + m.cbegin(), m.cend(), ", ", [](std::pair const &p) { return fmt::to_string(p); }); diff --git a/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h b/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h index fd77509e4d..e3947605c0 100644 --- a/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h +++ b/lib/utils/include/utils/full_binary_tree/get_path_to_leaf_map.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_PATH_TO_LEAF_MAP_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_FULL_BINARY_TREE_GET_PATH_TO_LEAF_MAP_H -#include "utils/containers/binary_merge_disjoint_maps.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" #include "utils/containers/map_keys.h" #include "utils/containers/multiset_union.h" #include "utils/full_binary_tree/binary_tree_path.dtg.h" @@ -30,7 +30,7 @@ std::unordered_map get_path_to_leaf_map( get_path_to_leaf_map(impl.get_right_child(parent), impl), [](BinaryTreePath const &p) { return nest_inside_right_child(p); }); - return binary_merge_disjoint_maps(left_map, right_map); + return binary_merge_disjoint_unordered_maps(left_map, right_map); }, [](Leaf const &leaf) -> std::unordered_map { return std::unordered_map{ diff --git a/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h index 418346bb36..4bb8373666 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_kwarg_dataflow_graph.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_KWARG_DATAFLOW_GRAPH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_KWARG_DATAFLOW_GRAPH_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/set_union.h" #include "utils/containers/values.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" @@ -26,7 +26,7 @@ struct UnorderedSetKwargDataflowGraph final Node new_node = this->node_source.new_node(); std::unordered_map> outputs = - generate_map( + generate_unordered_map( output_slots, [&](SlotName const &output_slot) -> KwargDataflowOutput { KwargDataflowOutput output = diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h index 159778bb6d..f05e6cb58a 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_dataflow_graph.h @@ -4,8 +4,8 @@ #include "utils/containers/count.h" #include "utils/containers/enumerate_vector.h" #include "utils/containers/filter.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_keys.h" #include "utils/containers/transform.h" #include "utils/containers/without_nullopts.h" @@ -23,6 +23,7 @@ #include "utils/graph/open_dataflow_graph/dataflow_graph_input_source.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge.h" #include "utils/graph/open_dataflow_graph/open_dataflow_edge_query.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -81,7 +82,7 @@ struct UnorderedSetLabelledOpenDataflowGraph final } std::unordered_set query_nodes(NodeQuery const &q) const override { - return filter(keys(this->nodes), + return filter(unordered_keys(this->nodes), [&](Node const &n) { return includes(q.nodes, n); }); } @@ -95,7 +96,7 @@ struct UnorderedSetLabelledOpenDataflowGraph final std::unordered_set query_outputs(DataflowOutputQuery const &q) const override { return without_nullopts(transform( - keys(this->values), + unordered_keys(this->values), [&](OpenDataflowValue const &v) -> std::optional { if (!v.has()) { return std::nullopt; @@ -128,12 +129,12 @@ struct UnorderedSetLabelledOpenDataflowGraph final std::unordered_set outputs = get_all_dataflow_outputs(view); std::unordered_set edges = get_edges(view); std::unordered_map labelled_outputs = - generate_map(outputs, + generate_unordered_map(outputs, [&](DataflowOutput const &o) { return view.at(o); }); this->inputs.clear(); this->nodes = - generate_map(nodes, [&](Node const &n) { return view.at(n); }); + generate_unordered_map(nodes, [&](Node const &n) { return view.at(n); }); this->edges = transform( edges, [](DataflowEdge const &e) { return OpenDataflowEdge{e}; }); this->values = map_keys(labelled_outputs, [](DataflowOutput const &o) { @@ -145,14 +146,14 @@ struct UnorderedSetLabelledOpenDataflowGraph final LabelledOpenDataflowGraphView const &view) override { - std::unordered_map nodes = generate_map( + std::unordered_map nodes = generate_unordered_map( get_nodes(view), [&](Node const &n) { return view.at(n); }); std::unordered_set edges = get_edges(view); std::unordered_set inputs = ::FlexFlow::get_open_dataflow_graph_inputs(view); std::unordered_map values = - generate_map(get_open_dataflow_values(view), + generate_unordered_map(get_open_dataflow_values(view), [&](OpenDataflowValue const &v) { return view.at(v); }); this->inputs = inputs; diff --git a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h index 2b20b94c96..ab60e3d364 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_labelled_open_kwarg_dataflow_graph.h @@ -4,7 +4,7 @@ #include "utils/containers/contains_key.h" #include "utils/containers/enumerate.h" #include "utils/containers/extend.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/map_values.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_edges.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_all_kwarg_dataflow_outputs.h" @@ -17,6 +17,7 @@ #include "utils/graph/open_kwarg_dataflow_graph/algorithms/get_all_open_kwarg_dataflow_edges.h" #include "utils/graph/open_kwarg_dataflow_graph/open_kwarg_dataflow_edge.h" #include "utils/overload.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -68,8 +69,8 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final } std::unordered_map> outputs = - generate_map( - keys(output_labels), + generate_unordered_map( + unordered_keys(output_labels), [&](SlotName const &output_slot) -> KwargDataflowOutput { ValueLabel value_label = output_labels.at(output_slot); @@ -106,7 +107,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final } std::unordered_set query_nodes(NodeQuery const &q) const override { - return filter(keys(this->nodes), + return filter(unordered_keys(this->nodes), [&](Node const &n) { return includes(q.nodes, n); }); } @@ -122,7 +123,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final std::unordered_set> query_outputs( KwargDataflowOutputQuery const &q) const override { - return filter(keys(this->outputs), + return filter(unordered_keys(this->outputs), [&](KwargDataflowOutput const &output) { return kwarg_dataflow_output_query_includes(q, output); }); @@ -130,7 +131,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final std::unordered_set> get_inputs() const override { - return keys(this->graph_inputs); + return unordered_keys(this->graph_inputs); } NodeLabel at(Node const &n) const override { @@ -159,7 +160,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final this->graph_inputs.clear(); this->nodes = - generate_map(view_nodes, [&](Node const &n) { return view.at(n); }); + generate_unordered_map(view_nodes, [&](Node const &n) { return view.at(n); }); this->edges = transform(view_edges, @@ -168,7 +169,7 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final return OpenKwargDataflowEdge{e}; }); this->outputs = - generate_map(view_outputs, [&](KwargDataflowOutput const &o) { + generate_unordered_map(view_outputs, [&](KwargDataflowOutput const &o) { return view.at(o); }); } @@ -186,16 +187,16 @@ struct UnorderedSetLabelledOpenKwargDataflowGraph final std::unordered_set> view_outputs = get_all_kwarg_dataflow_outputs(view); - this->graph_inputs = generate_map( + this->graph_inputs = generate_unordered_map( view_inputs, [&](KwargDataflowGraphInput const &i) { return view.at(OpenKwargDataflowValue{i}); }); this->nodes = - generate_map(view_nodes, [&](Node const &n) { return view.at(n); }); + generate_unordered_map(view_nodes, [&](Node const &n) { return view.at(n); }); this->edges = view_edges; this->outputs = - generate_map(view_outputs, [&](KwargDataflowOutput const &o) { + generate_unordered_map(view_outputs, [&](KwargDataflowOutput const &o) { return view.at(OpenKwargDataflowValue{o}); }); } diff --git a/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h b/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h index 3c66b2c689..9746533c17 100644 --- a/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h +++ b/lib/utils/include/utils/graph/instances/unordered_set_open_kwarg_dataflow_graph.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_OPEN_KWARG_DATAFLOW_GRAPH_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_INSTANCES_UNORDERED_SET_OPEN_KWARG_DATAFLOW_GRAPH_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_output_query.h" #include "utils/graph/node/node_source.h" #include "utils/graph/open_kwarg_dataflow_graph/i_open_kwarg_dataflow_graph.h" @@ -36,7 +36,7 @@ struct UnorderedSetOpenKwargDataflowGraph final } std::unordered_map> outputs = - generate_map( + generate_unordered_map( output_slots, [&](SlotName const &output_slot) -> KwargDataflowOutput { KwargDataflowOutput output = diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h index 32848f38a6..87ebab4c2d 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_slots_for_node.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_SLOTS_FOR_NODE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_INCOMING_SLOTS_FOR_NODE_H -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_incoming_kwarg_dataflow_edges_for_node.h" namespace FlexFlow { @@ -10,7 +10,7 @@ template std::unordered_set get_incoming_slots_for_node(KwargDataflowGraphView const &g, Node n) { - return keys(get_incoming_kwarg_dataflow_edges_for_node(g, n)); + return unordered_keys(get_incoming_kwarg_dataflow_edges_for_node(g, n)); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h new file mode 100644 index 0000000000..52c225d157 --- /dev/null +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h @@ -0,0 +1,28 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_VALUE_USES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_KWARG_DATAFLOW_VALUE_USES_H + +#include "utils/graph/kwarg_dataflow_graph/kwarg_dataflow_graph_view.h" + +namespace FlexFlow { + +template +std::unordered_set> + get_kwarg_dataflow_value_uses(KwargDataflowGraphView const &g, + KwargDataflowOutput const &v) { + + KwargDataflowEdgeQuery query = KwargDataflowEdgeQuery{ + /*src_nodes=*/query_set::match_single_value(v.node), + /*src_slots=*/query_set::match_single_value(v.slot_name), + /*dst_nodes=*/query_set::matchall(), + /*dst_slots=*/query_set::matchall(), + }; + + std::unordered_set> edges = g.query_edges(query); + + return transform(edges, + [&](KwargDataflowEdge const &e) { return e.dst; }); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h index 372dfed1e8..6cf2b4b8a4 100644 --- a/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h +++ b/lib/utils/include/utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_slots_for_node.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_SLOTS_FOR_NODE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_OUTGOING_SLOTS_FOR_NODE_H -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/graph/kwarg_dataflow_graph/algorithms/get_outgoing_kwarg_dataflow_outputs_for_node.h" namespace FlexFlow { @@ -10,7 +10,7 @@ template std::unordered_set get_outgoing_slots_for_node(KwargDataflowGraphView const &g, Node n) { - return keys(get_outgoing_kwarg_dataflow_outputs_for_node(g, n)); + return unordered_keys(get_outgoing_kwarg_dataflow_outputs_for_node(g, n)); } } // namespace FlexFlow diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h index 2115a03cda..502eeab73b 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/get_graph_data.h @@ -14,14 +14,14 @@ LabelledOpenDataflowGraphData get_graph_data( LabelledOpenDataflowGraphView const &g) { std::unordered_map node_data = - generate_map(get_nodes(g), [&](Node const &n) { return g.at(n); }); + generate_unordered_map(get_nodes(g), [&](Node const &n) { return g.at(n); }); std::unordered_set edges = get_edges(g); std::unordered_set inputs = g.get_inputs(); std::unordered_map value_data = - generate_map(get_open_dataflow_values(g), + generate_unordered_map(get_open_dataflow_values(g), [&](OpenDataflowValue const &v) { return g.at(v); }); return LabelledOpenDataflowGraphData{ diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h index 88132e0a79..580a35b3f7 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_input_ids.h @@ -30,10 +30,10 @@ LabelledOpenDataflowGraphView permute_input_ids( }; std::unordered_map node_labels = - generate_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); + generate_unordered_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); std::unordered_map value_labels = - generate_map(get_open_dataflow_values(permuted), + generate_unordered_map(get_open_dataflow_values(permuted), [&](OpenDataflowValue const &new_value) { return g.at(old_value_from_new(new_value)); }); diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h index 88950635d2..5119587654 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/permute_node_ids.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_NODE_IDS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -37,12 +37,12 @@ LabelledOpenDataflowGraphView permute_node_ids( }; std::unordered_map node_labels = - generate_map(get_nodes(permuted), [&](Node const &new_node) { + generate_unordered_map(get_nodes(permuted), [&](Node const &new_node) { return g.at(old_node_from_new(new_node)); }); std::unordered_map value_labels = - generate_map(get_open_dataflow_values(permuted), + generate_unordered_map(get_open_dataflow_values(permuted), [&](OpenDataflowValue const &new_value) { return g.at(old_value_from_new(new_value)); }); diff --git a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h index 92938d7142..fde90497e7 100644 --- a/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h +++ b/lib/utils/include/utils/graph/labelled_open_dataflow_graph/algorithms/rewrite_labels.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_dataflow_graph/algorithms/with_labelling.h" #include "utils/graph/labelled_open_dataflow_graph/labelled_open_dataflow_graph_view.h" #include "utils/graph/open_dataflow_graph/algorithms/get_open_dataflow_values.h" @@ -27,9 +27,9 @@ LabelledOpenDataflowGraphView rewrite_labels( }; std::unordered_map node_labels = - generate_map(get_nodes(g), get_new_node_label); + generate_unordered_map(get_nodes(g), get_new_node_label); std::unordered_map value_labels = - generate_map(get_open_dataflow_values(g), get_new_value_label); + generate_unordered_map(get_open_dataflow_values(g), get_new_value_label); return with_labelling(g, node_labels, value_labels); } diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h index d60c396274..e98b858019 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/get_labelled_open_kwarg_dataflow_graph_data.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_GET_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_DATA_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.dtg.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -28,12 +28,12 @@ LabelledOpenKwargDataflowGraphData{ - /*nodes=*/generate_map( + /*nodes=*/generate_unordered_map( get_nodes(g), [&](Node const &n) -> NodeLabel { return g.at(n); }), /*edges=*/get_all_open_kwarg_dataflow_edges(g), /*inputs=*/get_all_kwarg_dataflow_graph_inputs(g), /*outputs=*/ - generate_map( + generate_unordered_map( get_all_open_kwarg_dataflow_values(g), [&](OpenKwargDataflowValue const &v) -> ValueLabel { return g.at(v); }), diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h index d06b96e37f..b6a4366fc2 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/labelled_open_kwarg_dataflow_graph_data.h @@ -21,12 +21,12 @@ OpenKwargDataflowGraphData SlotName> const &labelled_data) { OpenKwargDataflowGraphData result = OpenKwargDataflowGraphData{ - /*nodes=*/keys(labelled_data.node_data), + /*nodes=*/unordered_keys(labelled_data.node_data), /*edges=*/labelled_data.edges, /*inputs=*/labelled_data.inputs, /*outputs=*/ filtrans( - keys(labelled_data.value_data), + unordered_keys(labelled_data.value_data), [](OpenKwargDataflowValue const &v) { return v.try_require_internal(); }), diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h index 223c2e7673..6e1c8bdc57 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_input_ids.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_INPUT_IDS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_INPUT_IDS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms.h" @@ -57,11 +57,11 @@ LabelledOpenKwargDataflowGraphView node_labels = - generate_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); + generate_unordered_map(get_nodes(permuted), [&](Node const &n) { return g.at(n); }); std::unordered_map, ValueLabel> - value_labels = generate_map( + value_labels = generate_unordered_map( get_all_open_kwarg_dataflow_values(permuted), [&](OpenKwargDataflowValue const &new_value) { return g.at(old_value_from_new(new_value)); }); diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h index 06728949df..d89c935d52 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/permute_labelled_open_kwarg_dataflow_graph_node_ids.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_NODE_IDS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_PERMUTE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_NODE_IDS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph_view.h" #include "utils/graph/node/algorithms/new_node.dtg.h" @@ -56,13 +56,13 @@ LabelledOpenKwargDataflowGraphView node_labels = - generate_map(get_nodes(permuted), [&](Node const &new_node) { + generate_unordered_map(get_nodes(permuted), [&](Node const &new_node) { return g.at(old_node_from_new(new_node)); }); std::unordered_map, ValueLabel> - value_labels = generate_map( + value_labels = generate_unordered_map( get_all_open_kwarg_dataflow_values(permuted), [&](OpenKwargDataflowValue const &new_value) { return g.at(old_value_from_new(new_value)); }); diff --git a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h index a632cd7b64..d5d1435462 100644 --- a/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h +++ b/lib/utils/include/utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/rewrite_labelled_open_kwarg_dataflow_graph_labels.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELS_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_ALGORITHMS_REWRITE_LABELLED_OPEN_KWARG_DATAFLOW_GRAPH_LABELS_H -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/algorithms/open_kwarg_dataflow_graph_view_with_labelling.h" #include "utils/graph/labelled_open_kwarg_dataflow_graph/labelled_open_kwarg_dataflow_graph.h" #include "utils/graph/node/algorithms.h" @@ -39,10 +39,10 @@ LabelledOpenKwargDataflowGraphView NewValueLabel { return f(v, g.at(v)); }; std::unordered_map node_labels = - generate_map(get_nodes(g), get_new_node_label); + generate_unordered_map(get_nodes(g), get_new_node_label); std::unordered_map, NewValueLabel> - value_labels = generate_map(get_all_open_kwarg_dataflow_values(g), + value_labels = generate_unordered_map(get_all_open_kwarg_dataflow_values(g), get_new_value_label); return open_kwarg_dataflow_graph_view_with_labelling( g, node_labels, value_labels); diff --git a/lib/utils/include/utils/graph/series_parallel/get_ancestors.h b/lib/utils/include/utils/graph/series_parallel/get_ancestors.h index f8e9f52cb5..e658c04060 100644 --- a/lib/utils/include/utils/graph/series_parallel/get_ancestors.h +++ b/lib/utils/include/utils/graph/series_parallel/get_ancestors.h @@ -2,6 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_GRAPH_SERIAL_PARALLLEL_GET_ANCESTORS_H #include "utils/graph/series_parallel/series_parallel_decomposition.dtg.h" +#include namespace FlexFlow { diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml index b4b975bb4e..a262de1a6b 100644 --- a/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_parallel_split.dtg.toml @@ -4,6 +4,7 @@ type = "struct" features = [ "eq", "hash", + "ord", "fmt", ] @@ -16,19 +17,19 @@ post_includes = [ ] includes = [ - "", + "", "", "utils/graph/node/node.dtg.h", ] src_includes = [ "utils/fmt/variant.h", - "utils/fmt/unordered_multiset.h", - "utils/hash/unordered_multiset.h", + "utils/fmt/multiset.h", + "utils/hash/multiset.h", ] [[fields]] name = "children" -type = "std::unordered_multiset>" +type = "std::multiset>" indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml index 008e58dc3f..8f0de7082d 100644 --- a/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_series_split.dtg.toml @@ -5,6 +5,7 @@ features = [ "eq", "hash", "fmt", + "ord", ] fwd_decls = [ @@ -23,6 +24,7 @@ includes = [ src_includes = [ "utils/fmt/variant.h", + "utils/ord/vector.h", "utils/fmt/vector.h", "utils/hash/vector.h", ] diff --git a/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml b/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml index c82e771385..d6289f1d75 100644 --- a/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/non_normal_sp_decomposition.dtg.toml @@ -3,6 +3,7 @@ name = "NonNormalSPDecomposition" type = "variant" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml index a3315d506b..eb907d4d43 100644 --- a/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/parallel_split.dtg.toml @@ -3,6 +3,7 @@ name = "ParallelSplit" type = "struct" features = [ "eq", + "ord", "hash", "fmt", ] @@ -16,18 +17,18 @@ post_includes = [ ] includes = [ - "", + "", "", "utils/graph/node/node.dtg.h", ] src_includes = [ "utils/fmt/variant.h", - "utils/fmt/unordered_multiset.h", - "utils/hash/unordered_multiset.h", + "utils/fmt/multiset.h", + "utils/hash/multiset.h", ] [[fields]] name = "children" -type = "std::unordered_multiset>" +type = "std::multiset>" indirect = true diff --git a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml index 4635bdd877..b47a8eabaa 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/series_parallel_decomposition.dtg.toml @@ -3,6 +3,7 @@ name = "SeriesParallelDecomposition" type = "variant" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml b/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml index e37762a059..cb5753627d 100644 --- a/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml +++ b/lib/utils/include/utils/graph/series_parallel/series_split.dtg.toml @@ -3,6 +3,7 @@ name = "SeriesSplit" type = "struct" features = [ "eq", + "ord", "hash", "fmt", ] @@ -23,6 +24,7 @@ includes = [ src_includes = [ "utils/fmt/variant.h", + "utils/ord/vector.h", "utils/fmt/vector.h", "utils/hash/vector.h", ] diff --git a/lib/utils/include/utils/json/check_is_jsonable.h b/lib/utils/include/utils/json/check_is_jsonable.h index 9d4f65f005..8597e11c22 100644 --- a/lib/utils/include/utils/json/check_is_jsonable.h +++ b/lib/utils/include/utils/json/check_is_jsonable.h @@ -7,9 +7,9 @@ namespace FlexFlow { #define CHECK_IS_JSONABLE(...) \ - static_assert(is_json_serializable<__VA_ARGS__>::value, \ + static_assert(::FlexFlow::is_json_serializable<__VA_ARGS__>::value, \ #__VA_ARGS__ " should be json serializeable"); \ - static_assert(is_json_deserializable<__VA_ARGS__>::value, \ + static_assert(::FlexFlow::is_json_deserializable<__VA_ARGS__>::value, \ #__VA_ARGS__ " should be json deserializeable") } // namespace FlexFlow diff --git a/lib/utils/include/utils/many_to_one/many_to_one.h b/lib/utils/include/utils/many_to_one/many_to_one.h index d2f727661c..a501a0672c 100644 --- a/lib/utils/include/utils/many_to_one/many_to_one.h +++ b/lib/utils/include/utils/many_to_one/many_to_one.h @@ -1,7 +1,7 @@ #ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_H -#include "utils/containers/keys.h" +#include "utils/containers/require_same.h" #include "utils/containers/try_at.h" #include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" @@ -19,6 +19,8 @@ #include #include #include +#include "utils/containers/set_of.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -87,7 +89,7 @@ struct ManyToOne { } std::unordered_set left_values() const { - return keys(this->m_l_to_r); + return unordered_keys(this->m_l_to_r); } std::unordered_set> left_groups() const { @@ -95,7 +97,7 @@ struct ManyToOne { } std::unordered_set right_values() const { - return keys(this->m_r_to_l); + return unordered_keys(this->m_r_to_l); } std::unordered_map const &l_to_r() const { @@ -106,6 +108,10 @@ struct ManyToOne { return this->m_r_to_l; } + bool empty() const { + return require_same(this->m_l_to_r.empty(), this->m_r_to_l.empty()); + } + private: std::unordered_map m_l_to_r; std::unordered_map> m_r_to_l; @@ -136,6 +142,22 @@ std::ostream &operator<<(std::ostream &s, ManyToOne const &m) { return (s << fmt::to_string(m)); } +template +std::unordered_set> + unstructured_relation_from_many_to_one(ManyToOne const &many_to_one) { + return unordered_set_of(many_to_one.l_to_r()); +} + +template +ManyToOne many_to_one_from_unstructured_relation( + std::unordered_set> const &relation) { + ManyToOne result; + for (auto const &lr : relation) { + result.insert(lr); + } + return result; +} + } // namespace FlexFlow namespace nlohmann { @@ -146,14 +168,16 @@ struct adl_serializer<::FlexFlow::ManyToOne> { CHECK_IS_JSON_DESERIALIZABLE(L); CHECK_IS_JSON_DESERIALIZABLE(R); - NOT_IMPLEMENTED(); + std::unordered_set> s = j; + + return ::FlexFlow::many_to_one_from_unstructured_relation(s); } static void to_json(json &j, ::FlexFlow::ManyToOne const &m) { CHECK_IS_JSON_SERIALIZABLE(L); CHECK_IS_JSON_SERIALIZABLE(R); - NOT_IMPLEMENTED(); + j = ::FlexFlow::set_of(::FlexFlow::unstructured_relation_from_many_to_one(m)); } }; diff --git a/lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h b/lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h deleted file mode 100644 index 171c6c15d6..0000000000 --- a/lib/utils/include/utils/many_to_one/many_to_one_from_unstructured_relation.h +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_UNSTRUCTURED_RELATION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_MANY_TO_ONE_FROM_UNSTRUCTURED_RELATION_H - -#include "utils/many_to_one/many_to_one.h" - -namespace FlexFlow { - -template -ManyToOne many_to_one_from_unstructured_relation( - std::unordered_set> const &relation) { - ManyToOne result; - for (auto const &lr : relation) { - result.insert(lr); - } - return result; -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h b/lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h deleted file mode 100644 index 676c5efa5d..0000000000 --- a/lib/utils/include/utils/many_to_one/unstructured_relation_from_many_to_one.h +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_UNSTRUCTURED_RELATION_FROM_MANY_TO_ONE_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_MANY_TO_ONE_UNSTRUCTURED_RELATION_FROM_MANY_TO_ONE_H - -#include "utils/containers/unordered_set_of.h" -#include "utils/many_to_one/many_to_one.h" - -namespace FlexFlow { - -template -std::unordered_set> - unstructured_relation_from_many_to_one(ManyToOne const &many_to_one) { - return unordered_set_of(many_to_one.l_to_r()); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/nonempty_set/nonempty_set.h b/lib/utils/include/utils/nonempty_set/nonempty_set.h new file mode 100644 index 0000000000..fe4b152bd5 --- /dev/null +++ b/lib/utils/include/utils/nonempty_set/nonempty_set.h @@ -0,0 +1,162 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONEMPTY_SET_NONEMPTY_SET_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_NONEMPTY_SET_NONEMPTY_SET_H + +#include +#include +#include "utils/hash-utils.h" +#include "utils/hash/set.h" +#include "utils/fmt/set.h" +#include "utils/positive_int/positive_int.h" +#include "utils/containers/unordered_set_of.h" +#include "utils/json/check_is_json_deserializable.h" +#include "utils/json/check_is_json_serializable.h" + +namespace FlexFlow { + +template +struct nonempty_set { +public: + nonempty_set() = delete; + + nonempty_set(std::initializer_list const &vs) : raw(vs) { + ASSERT(this->raw.size() > 0); + } + + explicit nonempty_set(std::set const &s) : raw(s) { + ASSERT(this->raw.size() > 0); + } + + bool operator==(nonempty_set const &other) const { + return this->unwrap_as_set() == other.unwrap_as_set(); + } + + bool operator!=(nonempty_set const &other) const { + return this->unwrap_as_set() != other.unwrap_as_set(); + } + + bool operator<(nonempty_set const &other) const { + return this->unwrap_as_set() < other.unwrap_as_set(); + } + + bool operator<=(nonempty_set const &other) const { + return this->unwrap_as_set() <= other.unwrap_as_set(); + } + + bool operator>(nonempty_set const &other) const { + return this->unwrap_as_set() > other.unwrap_as_set(); + } + + bool operator>=(nonempty_set const &other) const { + return this->unwrap_as_set() >= other.unwrap_as_set(); + } + + bool operator==(std::set const &other) const { + return this->unwrap_as_set() == other; + } + + bool operator!=(std::set const &other) const { + return this->unwrap_as_set() != other; + } + + void insert(T const &t) { + this->raw.insert(t); + } + + size_t size() const { + return this->raw.size(); + }; + + positive_int num_elements() const { + return positive_int{this->raw.size()}; + }; + + std::set const &unwrap_as_set() const { + return this->raw; + } + + std::unordered_set unwrap_as_unordered_set() const { + return unordered_set_of(this->raw); + } + + using const_iterator = typename std::set::const_iterator; + using value_type = T; + using reference = value_type &; + using const_reference = value_type const &; + + const_iterator begin() const { + return this->raw.cbegin(); + } + + const_iterator cbegin() const { + return this->raw.cbegin(); + } + + const_iterator end() const { + return this->raw.cend(); + } + + const_iterator cend() const { + return this->raw.cend(); + } + +private: + std::set raw; +}; + +template +bool operator==(std::set const &lhs, + nonempty_set const &rhs) { + return lhs == rhs.unwrap_as_set(); +} + +template +bool operator!=(std::set const &lhs, + nonempty_set const &rhs) { + return lhs != rhs.unwrap_as_set(); +} + +template +std::set format_as(nonempty_set const &s) { + return s.unwrap_as_set(); +} + +template +std::ostream &operator<<(std::ostream &s, nonempty_set const &m) { + return (s << fmt::to_string(m)); +} + +} // namespace FlexFlow + +namespace nlohmann { + +template +struct adl_serializer<::FlexFlow::nonempty_set> { + static ::FlexFlow::nonempty_set from_json(json const &j) { + CHECK_IS_JSON_DESERIALIZABLE(T); + + std::set s = j; + + return ::FlexFlow::nonempty_set{s}; + } + + static void to_json(json &j, ::FlexFlow::nonempty_set const &s) { + CHECK_IS_JSON_SERIALIZABLE(T); + + j = s.unwrap_as_set(); + } +}; + +} // namespace nlohmann + +namespace std { + +template +struct hash<::FlexFlow::nonempty_set> { + size_t operator()(::FlexFlow::nonempty_set const &x) const { + return ::FlexFlow::get_std_hash(x.unwrap_as_set()); + }; +}; + +} // namespace std + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many.h b/lib/utils/include/utils/one_to_many/one_to_many.h index 30d84d34c3..d57622b950 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many.h +++ b/lib/utils/include/utils/one_to_many/one_to_many.h @@ -4,25 +4,26 @@ #include "utils/containers/generate_map.h" #include "utils/containers/items.h" #include "utils/containers/keys.h" +#include "utils/containers/require_same.h" #include "utils/containers/transform.h" #include "utils/containers/try_at.h" -#include "utils/containers/unordered_set_of.h" #include "utils/containers/values.h" #include "utils/exception.h" -#include "utils/fmt/unordered_map.h" -#include "utils/fmt/unordered_set.h" +#include "utils/fmt/map.h" +#include "utils/fmt/set.h" #include "utils/hash-utils.h" #include "utils/hash/tuple.h" -#include "utils/hash/unordered_map.h" -#include "utils/hash/unordered_set.h" +#include "utils/hash/map.h" +#include "utils/hash/set.h" #include "utils/json/check_is_json_deserializable.h" #include "utils/json/check_is_json_serializable.h" -#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" +#include "utils/nonempty_set/nonempty_set.h" #include #include #include -#include -#include +#include +#include +#include "utils/containers/set_of.h" namespace FlexFlow { @@ -53,6 +54,22 @@ struct OneToMany { return this->tie() != other.tie(); } + bool operator<(OneToMany const &other) const { + return this->tie() < other.tie(); + } + + bool operator<=(OneToMany const &other) const { + return this->tie() <= other.tie(); + } + + bool operator>(OneToMany const &other) const { + return this->tie() > other.tie(); + } + + bool operator>=(OneToMany const &other) const { + return this->tie() >= other.tie(); + } + void insert(std::pair const &p) { L l = p.first; R r = p.second; @@ -65,7 +82,7 @@ struct OneToMany { if (contains_key(this->m_l_to_r, l)) { this->m_l_to_r.at(l).insert(r); } else { - this->m_l_to_r.insert({l, nonempty_unordered_set{{r}}}); + this->m_l_to_r.insert({l, nonempty_set{{r}}}); } } else if (found_l.value() == l) { return; @@ -79,14 +96,14 @@ struct OneToMany { } } - std::unordered_set> relation() const { + std::set> relation() const { return transform(items(this->m_r_to_l), [](std::pair const &p) -> std::pair { return {p.second, p.first}; }); } - nonempty_unordered_set const &at_l(L const &l) const { + nonempty_set const &at_l(L const &l) const { return this->m_l_to_r.at(l); } @@ -94,29 +111,33 @@ struct OneToMany { return this->m_r_to_l.at(r); } - std::unordered_set left_values() const { + std::set left_values() const { return keys(this->m_l_to_r); } - std::unordered_set right_values() const { + std::set right_values() const { return keys(this->m_r_to_l); } - std::unordered_set> right_groups() const { - return unordered_set_of(values(this->m_l_to_r)); + std::set> right_groups() const { + return set_of(values(this->m_l_to_r)); } - std::unordered_map> const &l_to_r() const { + std::map> const &l_to_r() const { return this->m_l_to_r; } - std::unordered_map const &r_to_l() const { + std::map const &r_to_l() const { return this->m_r_to_l; } + bool empty() const { + return require_same(this->m_l_to_r.empty(), this->m_r_to_l.empty()); + } + private: - std::unordered_map> m_l_to_r; - std::unordered_map m_r_to_l; + std::map> m_l_to_r; + std::map m_r_to_l; private: std::tuple @@ -128,7 +149,7 @@ struct OneToMany { }; template -std::unordered_map> +std::map> format_as(OneToMany const &m) { return generate_map(m.left_values(), [&](L const &l) { return m.at_l(l); }); } @@ -138,6 +159,25 @@ std::ostream &operator<<(std::ostream &s, OneToMany const &m) { return (s << fmt::to_string(m)); } +template +std::unordered_set> + unstructured_relation_from_one_to_many(OneToMany const &one_to_many) { + return transform(unordered_set_of(one_to_many.r_to_l()), + [](std::pair const &rl) -> std::pair { + return std::pair{rl.second, rl.first}; + }); +} + +template +OneToMany one_to_many_from_unstructured_relation( + std::unordered_set> const &rel) { + OneToMany result; + for (auto const &lr : rel) { + result.insert(lr); + } + return result; +} + } // namespace FlexFlow namespace nlohmann { @@ -148,14 +188,16 @@ struct adl_serializer<::FlexFlow::OneToMany> { CHECK_IS_JSON_DESERIALIZABLE(L); CHECK_IS_JSON_DESERIALIZABLE(R); - NOT_IMPLEMENTED(); + std::unordered_set> s = j; + + return ::FlexFlow::one_to_many_from_unstructured_relation(s); } static void to_json(json &j, ::FlexFlow::OneToMany const &m) { CHECK_IS_JSON_SERIALIZABLE(L); CHECK_IS_JSON_SERIALIZABLE(R); - NOT_IMPLEMENTED(); + j = ::FlexFlow::set_of(::FlexFlow::unstructured_relation_from_one_to_many(m)); } }; diff --git a/lib/utils/include/utils/one_to_many/one_to_many_from_unstructured_relation.h b/lib/utils/include/utils/one_to_many/one_to_many_filter_keys.h similarity index 50% rename from lib/utils/include/utils/one_to_many/one_to_many_from_unstructured_relation.h rename to lib/utils/include/utils/one_to_many/one_to_many_filter_keys.h index 11c0a767d6..5f31fbb584 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many_from_unstructured_relation.h +++ b/lib/utils/include/utils/one_to_many/one_to_many_filter_keys.h @@ -1,16 +1,17 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_UNSTRUCTURED_RELATION_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FROM_UNSTRUCTURED_RELATION_H +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FILTER_KEYS_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FILTER_KEYS_H #include "utils/one_to_many/one_to_many.h" namespace FlexFlow { -template -OneToMany one_to_many_from_unstructured_relation( - std::unordered_set> const &rel) { +template +OneToMany one_to_many_filter_keys(OneToMany const &m, F &&f) { OneToMany result; - for (auto const &lr : rel) { - result.insert(lr); + for (auto const &kv : unstructured_relation_from_one_to_many(m)) { + if (f(kv.first)) { + result.insert(kv); + } } return result; } diff --git a/lib/utils/include/utils/one_to_many/one_to_many_filter_values.h b/lib/utils/include/utils/one_to_many/one_to_many_filter_values.h new file mode 100644 index 0000000000..4694b06d65 --- /dev/null +++ b/lib/utils/include/utils/one_to_many/one_to_many_filter_values.h @@ -0,0 +1,22 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FILTER_VALUES_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_FILTER_VALUES_H + +#include "utils/one_to_many/one_to_many.h" + +namespace FlexFlow { + +template +OneToMany one_to_many_filter_values(OneToMany const &m, F &&f) { + OneToMany result; + for (auto const &kv : unstructured_relation_from_one_to_many(m)) { + if (f(kv.second)) { + result.insert(kv); + } + } + return result; +} + +} // namespace FlexFlow + + +#endif diff --git a/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h b/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h index a9afe98988..050f6ad7dd 100644 --- a/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h +++ b/lib/utils/include/utils/one_to_many/one_to_many_transform_values.h @@ -2,7 +2,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_ONE_TO_MANY_TRANSFORM_VALUES_H #include "utils/containers/transform.h" -#include "utils/one_to_many/one_to_many_from_unstructured_relation.h" +#include "utils/one_to_many/one_to_many.h" + namespace FlexFlow { @@ -13,7 +14,8 @@ template one_to_many_transform_values(OneToMany const &input, F f) { return one_to_many_from_unstructured_relation(transform( - input.relation(), [&](std::pair const &p) -> std::pair { + unordered_set_of(input.relation()), + [&](std::pair const &p) -> std::pair { return {p.first, f(p.second)}; })); } diff --git a/lib/utils/include/utils/one_to_many/require_one_to_many_is_bijection.h b/lib/utils/include/utils/one_to_many/require_one_to_many_is_bijection.h new file mode 100644 index 0000000000..9d31dc968c --- /dev/null +++ b/lib/utils/include/utils/one_to_many/require_one_to_many_is_bijection.h @@ -0,0 +1,23 @@ +#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_REQUIRE_ONE_TO_MANY_IS_BIJECTION_H +#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_REQUIRE_ONE_TO_MANY_IS_BIJECTION_H + +#include "utils/bidict/algorithms/bidict_from_map.h" +#include "utils/containers/map_values.h" +#include "utils/containers/get_only.h" +#include "utils/one_to_many/one_to_many.h" +#include "utils/nonempty_set/nonempty_set.h" + +namespace FlexFlow { + +template +bidict require_one_to_many_is_bijection(OneToMany const &otm) { + return bidict_from_map( + map_values(otm.l_to_r(), + [](nonempty_set const &s) -> R { + return get_only(s.unwrap_as_set()); + })); +} + +} // namespace FlexFlow + +#endif diff --git a/lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h b/lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h deleted file mode 100644 index 02ed225610..0000000000 --- a/lib/utils/include/utils/one_to_many/unstructured_relation_from_one_to_many.h +++ /dev/null @@ -1,21 +0,0 @@ -#ifndef _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_UNSTRUCTURED_RELATION_FROM_ONE_TO_MANY_H -#define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ONE_TO_MANY_UNSTRUCTURED_RELATION_FROM_ONE_TO_MANY_H - -#include "utils/containers/transform.h" -#include "utils/containers/unordered_set_of.h" -#include "utils/one_to_many/one_to_many.h" - -namespace FlexFlow { - -template -std::unordered_set> - unstructured_relation_from_one_to_many(OneToMany const &one_to_many) { - return transform(unordered_set_of(one_to_many.r_to_l()), - [](std::pair const &rl) -> std::pair { - return std::pair{rl.second, rl.first}; - }); -} - -} // namespace FlexFlow - -#endif diff --git a/lib/utils/include/utils/orthotope/dim_coord.h b/lib/utils/include/utils/orthotope/dim_coord.h index 87a05a7315..b9a10d6750 100644 --- a/lib/utils/include/utils/orthotope/dim_coord.h +++ b/lib/utils/include/utils/orthotope/dim_coord.h @@ -3,10 +3,10 @@ #include "utils/containers/all_of.h" #include "utils/containers/contains_key.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_all_assignments.h" #include "utils/containers/is_subseteq_of.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_from_keys_and_values.h" #include "utils/containers/map_values.h" #include "utils/containers/product.h" @@ -29,7 +29,7 @@ namespace FlexFlow { template std::unordered_set get_coord_dims(DimCoord const &coord) { - return keys(coord.raw); + return unordered_keys(coord.raw); } template @@ -65,7 +65,7 @@ DimCoord lift_dim_coord(DimCoord const &coord, ASSERT(is_subseteq_of(get_coord_dims(coord), lifted_dims)); return DimCoord{ - generate_map(lifted_dims, + generate_unordered_map(lifted_dims, [&](T const &dim) { if (contains_key(coord.raw, dim)) { return coord.raw.at(dim); diff --git a/lib/utils/include/utils/orthotope/dim_domain.h b/lib/utils/include/utils/orthotope/dim_domain.h index c940745a78..6bd63faeae 100644 --- a/lib/utils/include/utils/orthotope/dim_domain.h +++ b/lib/utils/include/utils/orthotope/dim_domain.h @@ -2,7 +2,7 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_DIM_DOMAIN_H #include "utils/containers/filter.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/map_from_keys_and_values.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/set_minus.h" @@ -27,7 +27,7 @@ nonnegative_int dim_domain_num_dims(DimDomain const &domain) { template std::unordered_set get_domain_dims(DimDomain const &domain) { - return keys(domain.dims); + return unordered_keys(domain.dims); } template diff --git a/lib/utils/include/utils/orthotope/dim_projection.dtg.toml b/lib/utils/include/utils/orthotope/dim_projection.dtg.toml index a530adac5d..9133c1a03f 100644 --- a/lib/utils/include/utils/orthotope/dim_projection.dtg.toml +++ b/lib/utils/include/utils/orthotope/dim_projection.dtg.toml @@ -3,6 +3,7 @@ name = "DimProjection" type = "variant" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/orthotope/down_projection.dtg.toml b/lib/utils/include/utils/orthotope/down_projection.dtg.toml index 9a642d2b9f..a83fd04e64 100644 --- a/lib/utils/include/utils/orthotope/down_projection.dtg.toml +++ b/lib/utils/include/utils/orthotope/down_projection.dtg.toml @@ -3,6 +3,7 @@ name = "DownProjection" type = "struct" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/orthotope/down_projection.h b/lib/utils/include/utils/orthotope/down_projection.h index f46a0f16c8..8fb1487ad3 100644 --- a/lib/utils/include/utils/orthotope/down_projection.h +++ b/lib/utils/include/utils/orthotope/down_projection.h @@ -48,7 +48,7 @@ DimCoord compute_down_projection(DownProjection const &projection, output_dims_of_down_projection(projection); return DimCoord{ - generate_map( + generate_unordered_map( output_dims, [&](R const &output_dim) { std::unordered_set src_dims = diff --git a/lib/utils/include/utils/orthotope/eq_projection.dtg.toml b/lib/utils/include/utils/orthotope/eq_projection.dtg.toml index 972952f907..456a44b2d3 100644 --- a/lib/utils/include/utils/orthotope/eq_projection.dtg.toml +++ b/lib/utils/include/utils/orthotope/eq_projection.dtg.toml @@ -3,6 +3,7 @@ name = "EqProjection" type = "struct" features = [ "eq", + "ord", "hash", "fmt", "rapidcheck", diff --git a/lib/utils/include/utils/orthotope/minimal_dim_domain.h b/lib/utils/include/utils/orthotope/minimal_dim_domain.h index 3934e2af62..f9bf5fa979 100644 --- a/lib/utils/include/utils/orthotope/minimal_dim_domain.h +++ b/lib/utils/include/utils/orthotope/minimal_dim_domain.h @@ -2,10 +2,8 @@ #define _FLEXFLOW_LIB_UTILS_INCLUDE_UTILS_ORTHOTOPE_MINIMAL_DIM_DOMAIN_H #include "utils/containers/are_disjoint.h" -#include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/containers/filtermap_values.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/map_from_keys_and_values.h" #include "utils/containers/map_values.h" #include "utils/containers/restrict_keys.h" @@ -16,6 +14,8 @@ #include "utils/orthotope/dim_ordering.dtg.h" #include "utils/orthotope/minimal_dim_domain.dtg.h" #include "utils/orthotope/minimal_orthotope.dtg.h" +#include "utils/containers/unordered_keys.h" +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" namespace FlexFlow { @@ -66,18 +66,18 @@ DimDomain dim_domain_from_minimal_dim_domain( ASSERT(are_disjoint(nontrivial_dims, trivial_dims)); return DimDomain{ - /*dims=*/binary_merge_disjoint_maps( + /*dims=*/binary_merge_disjoint_unordered_maps( map_values( minimal_dim_domain.dims, [](int_ge_two x) { return x.positive_int_from_int_ge_two(); }), - generate_map(trivial_dims, [](T const &) { return 1_p; })), + generate_unordered_map(trivial_dims, [](T const &) { return 1_p; })), }; } template std::unordered_set get_minimal_domain_dims(MinimalDimDomain const &domain) { - return keys(domain.dims); + return unordered_keys(domain.dims); } template diff --git a/lib/utils/include/utils/orthotope/up_projection.dtg.toml b/lib/utils/include/utils/orthotope/up_projection.dtg.toml index c99e6eec93..7f69c9dd0e 100644 --- a/lib/utils/include/utils/orthotope/up_projection.dtg.toml +++ b/lib/utils/include/utils/orthotope/up_projection.dtg.toml @@ -3,6 +3,7 @@ name = "UpProjection" type = "struct" features = [ "eq", + "ord", "hash", "fmt", ] diff --git a/lib/utils/include/utils/orthotope/up_projection.h b/lib/utils/include/utils/orthotope/up_projection.h index e485419fbb..1e241108e2 100644 --- a/lib/utils/include/utils/orthotope/up_projection.h +++ b/lib/utils/include/utils/orthotope/up_projection.h @@ -24,13 +24,13 @@ UpProjection make_empty_up_projection() { template std::unordered_set input_dims_of_up_projection(UpProjection const &projection) { - return projection.dim_mapping.left_values(); + return unordered_set_of(projection.dim_mapping.left_values()); } template std::unordered_set output_dims_of_up_projection(UpProjection const &projection) { - return projection.dim_mapping.right_values(); + return unordered_set_of(projection.dim_mapping.right_values()); } template diff --git a/lib/utils/src/utils/archetypes/jsonable_ordered_value_type.cc b/lib/utils/src/utils/archetypes/jsonable_ordered_value_type.cc new file mode 100644 index 0000000000..b83da321f7 --- /dev/null +++ b/lib/utils/src/utils/archetypes/jsonable_ordered_value_type.cc @@ -0,0 +1,7 @@ +#include "utils/archetypes/jsonable_ordered_value_type.h" + +namespace FlexFlow { + +template struct jsonable_ordered_value_type<0>; + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filter_keys.cc b/lib/utils/src/utils/bidict/algorithms/bidict_filter_keys.cc similarity index 58% rename from lib/utils/src/utils/bidict/algorithms/filter_keys.cc rename to lib/utils/src/utils/bidict/algorithms/bidict_filter_keys.cc index 57ef4e873d..5c84ec85b1 100644 --- a/lib/utils/src/utils/bidict/algorithms/filter_keys.cc +++ b/lib/utils/src/utils/bidict/algorithms/bidict_filter_keys.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/filter_keys.h" +#include "utils/bidict/algorithms/bidict_filter_keys.h" #include "utils/archetypes/value_type.h" namespace FlexFlow { @@ -7,6 +7,6 @@ using K = value_type<0>; using V = value_type<1>; using F = std::function; -template bidict filter_keys(bidict const &, F &&); +template bidict bidict_filter_keys(bidict const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filter_values.cc b/lib/utils/src/utils/bidict/algorithms/bidict_filter_values.cc similarity index 57% rename from lib/utils/src/utils/bidict/algorithms/filter_values.cc rename to lib/utils/src/utils/bidict/algorithms/bidict_filter_values.cc index 4cf58037ee..7adf808f85 100644 --- a/lib/utils/src/utils/bidict/algorithms/filter_values.cc +++ b/lib/utils/src/utils/bidict/algorithms/bidict_filter_values.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/filter_values.h" +#include "utils/bidict/algorithms/bidict_filter_values.h" #include "utils/archetypes/value_type.h" namespace FlexFlow { @@ -7,6 +7,6 @@ using K = value_type<0>; using V = value_type<1>; using F = std::function; -template bidict filter_values(bidict const &, F &&); +template bidict bidict_filter_values(bidict const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc b/lib/utils/src/utils/bidict/algorithms/bidict_filtrans_keys.cc similarity index 61% rename from lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc rename to lib/utils/src/utils/bidict/algorithms/bidict_filtrans_keys.cc index 1e506b8a51..954558a66a 100644 --- a/lib/utils/src/utils/bidict/algorithms/filtrans_keys.cc +++ b/lib/utils/src/utils/bidict/algorithms/bidict_filtrans_keys.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/filtrans_keys.h" +#include "utils/bidict/algorithms/bidict_filtrans_keys.h" #include "utils/archetypes/value_type.h" namespace FlexFlow { @@ -8,6 +8,6 @@ using V = value_type<1>; using K2 = value_type<2>; using F = std::function(K)>; -template bidict filtrans_keys(bidict const &, F &&); +template bidict bidict_filtrans_keys(bidict const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/filtrans_values.cc b/lib/utils/src/utils/bidict/algorithms/bidict_filtrans_values.cc similarity index 61% rename from lib/utils/src/utils/bidict/algorithms/filtrans_values.cc rename to lib/utils/src/utils/bidict/algorithms/bidict_filtrans_values.cc index 8d1352196c..303d40575a 100644 --- a/lib/utils/src/utils/bidict/algorithms/filtrans_values.cc +++ b/lib/utils/src/utils/bidict/algorithms/bidict_filtrans_values.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/filtrans_values.h" +#include "utils/bidict/algorithms/bidict_filtrans_values.h" #include "utils/archetypes/value_type.h" namespace FlexFlow { @@ -8,6 +8,6 @@ using V = value_type<1>; using V2 = value_type<2>; using F = std::function(V)>; -template bidict filtrans_values(bidict const &, F &&); +template bidict bidict_filtrans_values(bidict const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/bidict_unordered_set_of.cc b/lib/utils/src/utils/bidict/algorithms/bidict_unordered_set_of.cc new file mode 100644 index 0000000000..425c1d099c --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/bidict_unordered_set_of.cc @@ -0,0 +1,11 @@ +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +std::unordered_set> bidict_unordered_set_of(bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc new file mode 100644 index 0000000000..13a1bcd968 --- /dev/null +++ b/lib/utils/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -0,0 +1,12 @@ +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template bidict binary_merge_disjoint_bidicts(bidict const &, + bidict const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc index 754b8d2e90..2c27821d3b 100644 --- a/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc +++ b/lib/utils/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc @@ -1 +1,11 @@ #include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template bidict merge_disjoint_bidicts(std::vector> const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/bidict/algorithms/unordered_set_of.cc b/lib/utils/src/utils/bidict/algorithms/unordered_set_of.cc deleted file mode 100644 index a0bfa1525e..0000000000 --- a/lib/utils/src/utils/bidict/algorithms/unordered_set_of.cc +++ /dev/null @@ -1,11 +0,0 @@ -#include "utils/bidict/algorithms/unordered_set_of.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using K = value_type<0>; -using V = value_type<1>; - -std::unordered_set> unordered_set_of(bidict const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/cli/cli_parse.cc b/lib/utils/src/utils/cli/cli_parse.cc index 36d5837f9c..8f5f81324c 100644 --- a/lib/utils/src/utils/cli/cli_parse.cc +++ b/lib/utils/src/utils/cli/cli_parse.cc @@ -2,7 +2,7 @@ #include "utils/cli/cli_spec.h" #include "utils/containers/contains.h" #include "utils/containers/enumerate.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" namespace FlexFlow { @@ -27,7 +27,7 @@ tl::expected cli_parse_flag(CLISpec const &cli, tl::expected cli_parse(CLISpec const &cli, std::vector const &args) { CLIParseResult result = CLIParseResult{ - generate_map(cli_get_flag_keys(cli), + generate_unordered_map(cli_get_flag_keys(cli), [](CLIFlagKey const &) { return false; }), {}, }; diff --git a/lib/utils/src/utils/containers/binary_merge_disjoint_maps.cc b/lib/utils/src/utils/containers/binary_merge_disjoint_maps.cc index 0569b3ed0b..95d19083ac 100644 --- a/lib/utils/src/utils/containers/binary_merge_disjoint_maps.cc +++ b/lib/utils/src/utils/containers/binary_merge_disjoint_maps.cc @@ -1,13 +1,14 @@ #include "utils/containers/binary_merge_disjoint_maps.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -template std::unordered_map - binary_merge_disjoint_maps(std::unordered_map const &, - std::unordered_map const &); +template std::map + binary_merge_disjoint_maps(std::map const &, + std::map const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_disjoint_unordered_maps.cc b/lib/utils/src/utils/containers/binary_merge_disjoint_unordered_maps.cc new file mode 100644 index 0000000000..60ccf3a7e5 --- /dev/null +++ b/lib/utils/src/utils/containers/binary_merge_disjoint_unordered_maps.cc @@ -0,0 +1,13 @@ +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template std::unordered_map + binary_merge_disjoint_unordered_maps(std::unordered_map const &, + std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_maps_with.cc b/lib/utils/src/utils/containers/binary_merge_maps_with.cc index 35f771f60c..4679d21227 100644 --- a/lib/utils/src/utils/containers/binary_merge_maps_with.cc +++ b/lib/utils/src/utils/containers/binary_merge_maps_with.cc @@ -1,13 +1,14 @@ #include "utils/containers/binary_merge_maps_with.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; using F = std::function; -template std::unordered_map binary_merge_maps_with( - std::unordered_map const &, std::unordered_map const &, F &&); +template std::map binary_merge_maps_with( + std::map const &, std::map const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_maps_with_left_dominating.cc b/lib/utils/src/utils/containers/binary_merge_maps_with_left_dominating.cc index c459e82061..d5b4f6cebe 100644 --- a/lib/utils/src/utils/containers/binary_merge_maps_with_left_dominating.cc +++ b/lib/utils/src/utils/containers/binary_merge_maps_with_left_dominating.cc @@ -1,13 +1,14 @@ #include "utils/containers/binary_merge_maps_with_left_dominating.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -template std::unordered_map - binary_merge_maps_with_left_dominating(std::unordered_map const &, - std::unordered_map const &); +template std::map + binary_merge_maps_with_left_dominating(std::map const &, + std::map const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_maps_with_right_dominating.cc b/lib/utils/src/utils/containers/binary_merge_maps_with_right_dominating.cc index df934387d2..bbc799150e 100644 --- a/lib/utils/src/utils/containers/binary_merge_maps_with_right_dominating.cc +++ b/lib/utils/src/utils/containers/binary_merge_maps_with_right_dominating.cc @@ -1,13 +1,14 @@ #include "utils/containers/binary_merge_maps_with_right_dominating.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -template std::unordered_map - binary_merge_maps_with_right_dominating(std::unordered_map const &, - std::unordered_map const &); +template std::map + binary_merge_maps_with_right_dominating(std::map const &, + std::map const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_unordered_maps_with.cc b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with.cc new file mode 100644 index 0000000000..de8c484d0c --- /dev/null +++ b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with.cc @@ -0,0 +1,13 @@ +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using F = std::function; + +template std::unordered_map binary_merge_unordered_maps_with( + std::unordered_map const &, std::unordered_map const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc new file mode 100644 index 0000000000..d777eb0e29 --- /dev/null +++ b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc @@ -0,0 +1,13 @@ +#include "utils/containers/binary_merge_unordered_maps_with_left_dominating.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template std::unordered_map + binary_merge_unordered_maps_with_left_dominating(std::unordered_map const &, + std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc new file mode 100644 index 0000000000..f5586cec6b --- /dev/null +++ b/lib/utils/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc @@ -0,0 +1,13 @@ +#include "utils/containers/binary_merge_unordered_maps_with_right_dominating.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template + std::unordered_map binary_merge_unordered_maps_with_right_dominating( + std::unordered_map const &, std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/filter.cc b/lib/utils/src/utils/containers/filter.cc index dc11d0dffa..4931d97704 100644 --- a/lib/utils/src/utils/containers/filter.cc +++ b/lib/utils/src/utils/containers/filter.cc @@ -1 +1,45 @@ #include "utils/containers/filter.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using VT0 = value_type<0>; +using VT1 = value_type<1>; +using OVT0 = ordered_value_type<0>; + +template + std::vector filter(std::vector const &, + std::function const &); + +template + std::unordered_set filter(std::unordered_set const &, + std::function const &); + +template + std::unordered_map + filter(std::unordered_map const &, + std::function const &)> const &); + +template + std::set filter( + std::set const &, + std::function const &); + +template + std::map + filter(std::map const &, + std::function const &)> const &); + +template + std::multiset + filter(std::multiset const &, + std::function const &); + +template + std::unordered_multiset + filter(std::unordered_multiset const &, + std::function const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/generate_map.cc b/lib/utils/src/utils/containers/generate_map.cc index 54bbe13dc9..57caa3d062 100644 --- a/lib/utils/src/utils/containers/generate_map.cc +++ b/lib/utils/src/utils/containers/generate_map.cc @@ -1 +1,13 @@ #include "utils/containers/generate_map.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; +using F = std::function; + +template std::map generate_map(std::vector const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/generate_unordered_map.cc b/lib/utils/src/utils/containers/generate_unordered_map.cc new file mode 100644 index 0000000000..2287e45632 --- /dev/null +++ b/lib/utils/src/utils/containers/generate_unordered_map.cc @@ -0,0 +1,13 @@ +#include "utils/containers/generate_unordered_map.h" +#include "utils/archetypes/value_type.h" +#include + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using F = std::function; + +template std::unordered_map generate_unordered_map(std::unordered_set const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/group_by.cc b/lib/utils/src/utils/containers/group_by.cc index efd8c2032b..a41ab4dc62 100644 --- a/lib/utils/src/utils/containers/group_by.cc +++ b/lib/utils/src/utils/containers/group_by.cc @@ -4,8 +4,8 @@ namespace FlexFlow { -using K = value_type<0>; -using V = value_type<1>; +using K = ordered_value_type<0>; +using V = ordered_value_type<1>; using F = std::function; template OneToMany group_by(std::unordered_set const &, F &&); diff --git a/lib/utils/src/utils/containers/is_submapeq_of.cc b/lib/utils/src/utils/containers/is_submapeq_of.cc index 567d94fac5..f8fd627b3d 100644 --- a/lib/utils/src/utils/containers/is_submapeq_of.cc +++ b/lib/utils/src/utils/containers/is_submapeq_of.cc @@ -1 +1,12 @@ #include "utils/containers/is_submapeq_of.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +bool is_submapeq_of(std::unordered_map const &, std::unordered_map const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/items.cc b/lib/utils/src/utils/containers/items.cc index 3b1e80452a..193cec9ca8 100644 --- a/lib/utils/src/utils/containers/items.cc +++ b/lib/utils/src/utils/containers/items.cc @@ -1 +1,14 @@ #include "utils/containers/items.h" +#include "utils/archetypes/ordered_value_type.h" +#include +#include + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = ordered_value_type<1>; + +template std::set> items(std::unordered_map const &); +template std::set> items(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/keys.cc b/lib/utils/src/utils/containers/keys.cc index 6c6abadd56..96db33f4c7 100644 --- a/lib/utils/src/utils/containers/keys.cc +++ b/lib/utils/src/utils/containers/keys.cc @@ -1 +1,13 @@ #include "utils/containers/keys.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; + +template std::set keys(std::unordered_map const &); +template std::set keys(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_from_pairs.cc b/lib/utils/src/utils/containers/map_from_pairs.cc index ba0eed8c15..8dc8ffa29c 100644 --- a/lib/utils/src/utils/containers/map_from_pairs.cc +++ b/lib/utils/src/utils/containers/map_from_pairs.cc @@ -1,12 +1,21 @@ #include "utils/containers/map_from_pairs.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" +#include +#include +#include namespace FlexFlow { -using K = value_type<0>; -using V = value_type<1>; +using K = ordered_value_type<0>; +using V = ordered_value_type<1>; -template std::unordered_map +template std::map + map_from_pairs(std::set> const &); + +template std::map map_from_pairs(std::unordered_set> const &); +template std::map + map_from_pairs(std::vector> const &); + } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_from_unordered.cc b/lib/utils/src/utils/containers/map_from_unordered.cc new file mode 100644 index 0000000000..11558af765 --- /dev/null +++ b/lib/utils/src/utils/containers/map_from_unordered.cc @@ -0,0 +1,13 @@ +#include "utils/containers/map_from_unordered.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; + +template + std::map map_from_unordered(std::unordered_map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_keys.cc b/lib/utils/src/utils/containers/map_keys.cc index 7473c7e16d..daf9ed25d4 100644 --- a/lib/utils/src/utils/containers/map_keys.cc +++ b/lib/utils/src/utils/containers/map_keys.cc @@ -1 +1,22 @@ #include "utils/containers/map_keys.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using VT0 = value_type<0>; +using VT1 = value_type<1>; +using VT2 = value_type<2>; + +template + std::unordered_map map_keys(std::unordered_map const &, + std::function &&); + +using OV0 = ordered_value_type<0>; +using OV1 = ordered_value_type<1>; + +template + std::map map_keys(std::map const &m, + std::function &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_keys_and_values.cc b/lib/utils/src/utils/containers/map_keys_and_values.cc index b3b306988e..95608a09bd 100644 --- a/lib/utils/src/utils/containers/map_keys_and_values.cc +++ b/lib/utils/src/utils/containers/map_keys_and_values.cc @@ -1,5 +1,6 @@ #include "utils/containers/map_keys_and_values.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { @@ -13,4 +14,11 @@ using FV = std::function; template std::unordered_map map_keys_and_values( std::unordered_map const &, FK const &, FV const &); +using OK = ordered_value_type<0>; +using OK2 = ordered_value_type<1>; +using OFK = std::function; + +template std::map map_keys_and_values( + std::map const &, OFK const &, FV const &); + } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_values.cc b/lib/utils/src/utils/containers/map_values.cc index e26035e8b1..e850ecf31e 100644 --- a/lib/utils/src/utils/containers/map_values.cc +++ b/lib/utils/src/utils/containers/map_values.cc @@ -1,5 +1,6 @@ #include "utils/containers/map_values.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { @@ -11,4 +12,11 @@ using F = std::function; template std::unordered_map map_values(std::unordered_map const &, F &&); +using KO = ordered_value_type<0>; +using VO = value_type<1>; +using VO2 = value_type<2>; +using FO = std::function; + +template std::map map_values(std::map const &, FO &&); + } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/map_values2.cc b/lib/utils/src/utils/containers/map_values2.cc index 6aba8f4db0..8840f15aee 100644 --- a/lib/utils/src/utils/containers/map_values2.cc +++ b/lib/utils/src/utils/containers/map_values2.cc @@ -1,14 +1,21 @@ #include "utils/containers/map_values2.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; -using V = value_type<1>; -using V2 = value_type<2>; -using F = std::function; +using VT0 = value_type<0>; +using VT1 = value_type<1>; +using VT2 = value_type<2>; -template std::unordered_map map_values2(std::unordered_map const &, - F &&); +template std::unordered_map map_values2( + std::unordered_map const &, + std::function &&); + +using OT0 = ordered_value_type<0>; + +template std::map map_values2( + std::map const &, + std::function &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_disjoint_maps.cc b/lib/utils/src/utils/containers/merge_disjoint_maps.cc index dec8ee0618..e810b6b4f0 100644 --- a/lib/utils/src/utils/containers/merge_disjoint_maps.cc +++ b/lib/utils/src/utils/containers/merge_disjoint_maps.cc @@ -1,12 +1,13 @@ #include "utils/containers/merge_disjoint_maps.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -using C = std::vector>; +using C = std::vector>; -template std::unordered_map merge_disjoint_maps(C const &); +template std::map merge_disjoint_maps(C const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_disjoint_unordered_maps.cc b/lib/utils/src/utils/containers/merge_disjoint_unordered_maps.cc new file mode 100644 index 0000000000..1ef6af0877 --- /dev/null +++ b/lib/utils/src/utils/containers/merge_disjoint_unordered_maps.cc @@ -0,0 +1,12 @@ +#include "utils/containers/merge_disjoint_unordered_maps.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using C = std::vector>; + +template std::unordered_map merge_disjoint_unordered_maps(C const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_in_map.cc b/lib/utils/src/utils/containers/merge_in_map.cc index ada1a803ad..618128ff3b 100644 --- a/lib/utils/src/utils/containers/merge_in_map.cc +++ b/lib/utils/src/utils/containers/merge_in_map.cc @@ -1,12 +1,12 @@ #include "utils/containers/merge_in_map.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; -using V = value_type<1>; +using K = ordered_value_type<0>; +using V = ordered_value_type<1>; -template void merge_in_map(std::unordered_map const &, - std::unordered_map &); +template void merge_in_map(std::map const &, + std::map &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_in_unordered_map.cc b/lib/utils/src/utils/containers/merge_in_unordered_map.cc new file mode 100644 index 0000000000..95228dc5f2 --- /dev/null +++ b/lib/utils/src/utils/containers/merge_in_unordered_map.cc @@ -0,0 +1,12 @@ +#include "utils/containers/merge_in_unordered_map.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; + +template void merge_in_unordered_map(std::unordered_map const &, + std::unordered_map &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_maps_with.cc b/lib/utils/src/utils/containers/merge_maps_with.cc index 0375b16bc4..b5b471a1f8 100644 --- a/lib/utils/src/utils/containers/merge_maps_with.cc +++ b/lib/utils/src/utils/containers/merge_maps_with.cc @@ -1,13 +1,14 @@ #include "utils/containers/merge_maps_with.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; using F = std::function; -template std::unordered_map - merge_maps_with(std::vector> const &, F &&); +template std::map + merge_maps_with(std::vector> const &, F &&); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_maps_with_right_dominating.cc b/lib/utils/src/utils/containers/merge_maps_with_right_dominating.cc index d8c269d6e9..f33b46780c 100644 --- a/lib/utils/src/utils/containers/merge_maps_with_right_dominating.cc +++ b/lib/utils/src/utils/containers/merge_maps_with_right_dominating.cc @@ -1,12 +1,13 @@ #include "utils/containers/merge_maps_with_right_dominating.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; +using K = ordered_value_type<0>; using V = value_type<1>; -using C = std::vector>; +using C = std::vector>; -template std::unordered_map merge_maps_with_right_dominating(C const &); +template std::map merge_maps_with_right_dominating(C const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_unordered_maps_with.cc b/lib/utils/src/utils/containers/merge_unordered_maps_with.cc new file mode 100644 index 0000000000..60218312f3 --- /dev/null +++ b/lib/utils/src/utils/containers/merge_unordered_maps_with.cc @@ -0,0 +1,14 @@ +#include "utils/containers/merge_unordered_maps_with.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using F = std::function; + +std::unordered_map + merge_unordered_maps_with(std::vector> const &, + F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/merge_unordered_maps_with_right_dominating.cc b/lib/utils/src/utils/containers/merge_unordered_maps_with_right_dominating.cc new file mode 100644 index 0000000000..1dd7da70a3 --- /dev/null +++ b/lib/utils/src/utils/containers/merge_unordered_maps_with_right_dominating.cc @@ -0,0 +1,12 @@ +#include "utils/containers/merge_unordered_maps_with_right_dominating.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = value_type<0>; +using V = value_type<1>; +using C = std::vector>; + +template std::unordered_map merge_unordered_maps_with_right_dominating(C const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/require_all_of.cc b/lib/utils/src/utils/containers/require_all_of.cc new file mode 100644 index 0000000000..7fbf48c54f --- /dev/null +++ b/lib/utils/src/utils/containers/require_all_of.cc @@ -0,0 +1,34 @@ +#include "utils/containers/require_all_of.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" +#include +#include + +namespace FlexFlow { + +using T1 = value_type<0>; +using F1 = std::function; + +template void require_all_of(std::vector const &, F1 &&); +template void require_all_of(std::unordered_set const &, F1 &&); +template void require_all_of(std::unordered_multiset const &, F1 &&); + +using T2 = ordered_value_type<0>; +using F2 = std::function; + +template void require_all_of(std::set const &, F2 &&); +template void require_all_of(std::multiset const &, F2 &&); + +using K3 = value_type<0>; +using V3 = value_type<1>; +using F3 = std::function; + +template void require_all_of(std::unordered_map const &, F3 &&); + +using K4 = ordered_value_type<0>; +using V4 = ordered_value_type<1>; +using F4 = std::function; + +template void require_all_of(std::map const &, F4 &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/require_only_key.cc b/lib/utils/src/utils/containers/require_only_key.cc index ac8c201303..26ec81528a 100644 --- a/lib/utils/src/utils/containers/require_only_key.cc +++ b/lib/utils/src/utils/containers/require_only_key.cc @@ -1,5 +1,6 @@ #include "utils/containers/require_only_key.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { @@ -8,4 +9,8 @@ using V = value_type<1>; template V require_only_key(std::unordered_map const &, K const &); +using K2 = ordered_value_type<0>; + +template V require_only_key(std::map const &, K2 const &); + } // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/restrict_keys.cc b/lib/utils/src/utils/containers/restrict_keys.cc index d2749b7ea2..13584abec1 100644 --- a/lib/utils/src/utils/containers/restrict_keys.cc +++ b/lib/utils/src/utils/containers/restrict_keys.cc @@ -1 +1,21 @@ #include "utils/containers/restrict_keys.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using VT0 = value_type<0>; +using VT1 = value_type<1>; + +template + std::unordered_map restrict_keys(std::unordered_map const &, + std::unordered_set const &); + +using OV0 = ordered_value_type<0>; + +template + std::map restrict_keys(std::map const &, + std::set const &); + + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/transform.cc b/lib/utils/src/utils/containers/transform.cc index 7cd5a56ed4..55255cbc1e 100644 --- a/lib/utils/src/utils/containers/transform.cc +++ b/lib/utils/src/utils/containers/transform.cc @@ -1 +1,47 @@ #include "utils/containers/transform.h" +#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using In = value_type<0>; +using Out = value_type<1>; +using F = std::function; + +template std::vector transform(std::vector const &, F const &); + +template std::unordered_set transform(std::unordered_set const &, F const &); + +template std::unordered_multiset transform(std::unordered_multiset const &, F const &); + +using In2 = ordered_value_type<0>; +using Out2 = ordered_value_type<1>; +using F2 = std::function; + +template std::set transform(std::set const &, F2 const &); + +template std::multiset transform(std::multiset const &v, F2 const &f); + +using F3 = std::function; + +template std::string transform(std::string const &, F3 const &); + +using K = value_type<3>; +using V = value_type<4>; +using K2 = value_type<5>; +using V2 = value_type<6>; + +template std::unordered_map transform(std::unordered_map const &, + std::function(K const &, V const &)> const &); + +using K3 = ordered_value_type<3>; +using V3 = value_type<4>; +using K4 = ordered_value_type<5>; +using V4 = value_type<6>; + +template std::map transform(std::map const &, + std::function(K3 const &, V3 const &)> const &); + +template std::optional transform(std::optional const &o, + std::function const &); +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/transform_pairs.cc b/lib/utils/src/utils/containers/transform_pairs.cc new file mode 100644 index 0000000000..4afda936e4 --- /dev/null +++ b/lib/utils/src/utils/containers/transform_pairs.cc @@ -0,0 +1,17 @@ +#include "utils/containers/transform_pairs.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using L = value_type<0>; +using R = value_type<1>; +using Out = value_type<2>; +using F = std::function; + +template std::vector transform_pairs(std::vector> const &, + F &&); + +template std::unordered_set + transform_pairs(std::unordered_set> const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/unordered_items.cc b/lib/utils/src/utils/containers/unordered_items.cc new file mode 100644 index 0000000000..9b58cfd18e --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_items.cc @@ -0,0 +1,15 @@ +#include "utils/containers/unordered_items.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" +#include +#include + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; + +template std::unordered_set> unordered_items(std::unordered_map const &); +template std::unordered_set> unordered_items(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/unordered_keys.cc b/lib/utils/src/utils/containers/unordered_keys.cc new file mode 100644 index 0000000000..e850b5f460 --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_keys.cc @@ -0,0 +1,13 @@ +#include "utils/containers/unordered_keys.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<1>; + +template std::unordered_set unordered_keys(std::unordered_map const &); +std::unordered_set unordered_keys(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/unordered_map_from_map.cc b/lib/utils/src/utils/containers/unordered_map_from_map.cc new file mode 100644 index 0000000000..a0ffa034b5 --- /dev/null +++ b/lib/utils/src/utils/containers/unordered_map_from_map.cc @@ -0,0 +1,12 @@ +#include "utils/containers/unordered_map_from_map.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/value_type.h" + +namespace FlexFlow { + +using K = ordered_value_type<0>; +using V = value_type<0>; + +template std::unordered_map unordered_map_from_map(std::map const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/containers/zip_values_strict.cc b/lib/utils/src/utils/containers/zip_values_strict.cc index b9bed29a1b..c1710129a5 100644 --- a/lib/utils/src/utils/containers/zip_values_strict.cc +++ b/lib/utils/src/utils/containers/zip_values_strict.cc @@ -1,14 +1,23 @@ #include "utils/containers/zip_values_strict.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using K = value_type<0>; -using V1 = value_type<1>; -using V2 = value_type<2>; +using VT0 = value_type<0>; +using VT1 = value_type<1>; +using VT2 = value_type<2>; + +template std::unordered_map> + zip_values_strict(std::unordered_map const &, + std::unordered_map const &); + +using OV0 = ordered_value_type<0>; + +template std::map> + zip_values_strict(std::map const &, + std::map const &); + -template std::unordered_map> - zip_values_strict(std::unordered_map const &, - std::unordered_map const &); } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc index f617c52593..c9889a49e7 100644 --- a/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc +++ b/lib/utils/src/utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.cc @@ -1,5 +1,4 @@ #include "utils/graph/dataflow_graph/algorithms/dataflow_graph_as_dot.h" -#include "utils/containers/generate_map.h" #include "utils/containers/map_keys.h" #include "utils/dot/dot_file.h" #include "utils/dot/dot_html_from_json.h" diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc index 2da1b208f4..5a9ca06c5e 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_dominators_map.cc @@ -1,5 +1,5 @@ #include "utils/graph/digraph/algorithms/get_dominators_map.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/restrict_keys.h" #include "utils/containers/transform.h" #include "utils/containers/values.h" @@ -25,7 +25,7 @@ std::unordered_map> } std::unordered_map> result = - generate_map(get_nodes(g), [&](Node const &) { return get_nodes(g); }); + generate_unordered_map(get_nodes(g), [&](Node const &) { return get_nodes(g); }); while (!queue.empty()) { Node n = queue.front(); queue.pop(); diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc index 34cc7fcc6f..42a9830837 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_dominators_map.cc @@ -1,7 +1,7 @@ #include "utils/graph/digraph/algorithms/get_imm_dominators_map.h" #include "utils/containers/concat_vectors.h" #include "utils/containers/filter_values.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_element_counts.h" #include "utils/containers/get_only.h" #include "utils/containers/keys.h" @@ -27,14 +27,14 @@ std::unordered_map> })); std::unordered_map dominator_counts = get_element_counts(recursive_dominator_list); - std::unordered_set imm_dominators = keys( + std::unordered_set imm_dominators = unordered_keys( filter_values(dominator_counts, [](int count) { return count <= 1; })); - assert(imm_dominators.size() <= 1); + ASSERT(imm_dominators.size() <= 1); return maybe_get_only(imm_dominators); }; - return generate_map(get_nodes(g), get_imm_dominator); + return generate_unordered_map(get_nodes(g), get_imm_dominator); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc index 39523f2ec1..899ed385c7 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_imm_post_dominator.cc @@ -1,5 +1,5 @@ #include "utils/graph/digraph/algorithms/get_imm_post_dominator.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_one_of.h" #include "utils/containers/get_only.h" #include "utils/containers/intersection.h" @@ -33,7 +33,7 @@ std::optional Node contracted_node = get_one_of(nodes); std::unordered_map contraction = - generate_map(nodes, [&](Node const &) { return contracted_node; }); + generate_unordered_map(nodes, [&](Node const &) { return contracted_node; }); return get_imm_post_dominator(apply_contraction(g, contraction), contracted_node); } diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc index db09dd07d6..e6c9d5e557 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_incoming_edges.cc @@ -2,6 +2,8 @@ #include "utils/containers/group_by.h" #include "utils/containers/map_values.h" #include "utils/containers/set_of.h" +#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { @@ -16,14 +18,18 @@ std::unordered_set get_incoming_edges(DiGraphView const &g, std::unordered_map> get_incoming_edges(DiGraphView const &g, std::unordered_set const &ns) { - std::unordered_map> result = - map_values(group_by(g.query_edges(DirectedEdgeQuery{ - query_set::matchall(), - query_set::match_values_in(set_of(ns)), - }), - [](DirectedEdge const &e) { return e.dst; }) - .l_to_r(), - [](nonempty_unordered_set const &s) + + std::map> by_dst = + group_by(g.query_edges(DirectedEdgeQuery{ + query_set::matchall(), + query_set::match_values_in(set_of(ns)), + }), + [](DirectedEdge const &e) { return e.dst; }) + .l_to_r(); + + std::map> result = + map_values(by_dst, + [](nonempty_set const &s) -> std::unordered_set { return s.unwrap_as_unordered_set(); }); @@ -32,7 +38,7 @@ std::unordered_map> result[n]; } - return result; + return unordered_map_from_map(result); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc index c2057472cf..883d8e3725 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/get_outgoing_edges.cc @@ -2,21 +2,26 @@ #include "utils/containers/group_by.h" #include "utils/containers/map_values.h" #include "utils/containers/set_of.h" +#include "utils/nonempty_unordered_set/nonempty_unordered_set.h" +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { std::unordered_map> get_outgoing_edges(DiGraphView const &g, std::unordered_set const &ns) { - std::unordered_map> result = - map_values(group_by(g.query_edges(DirectedEdgeQuery{ - query_set::match_values_in(set_of(ns)), - query_set::matchall(), - }), - [](DirectedEdge const &e) { return e.src; }) - .l_to_r(), - [](nonempty_unordered_set const &s) - -> std::unordered_set { + + std::map> by_src = + group_by(g.query_edges(DirectedEdgeQuery{ + query_set::match_values_in(set_of(ns)), + query_set::matchall(), + }), + [](DirectedEdge const &e) { return e.src; }) + .l_to_r(); + + std::map> result = + map_values(by_src, + [](nonempty_set const &s) -> std::unordered_set { return s.unwrap_as_unordered_set(); }); @@ -24,7 +29,7 @@ std::unordered_map> result[n]; } - return result; + return unordered_map_from_map(result); } std::unordered_set get_outgoing_edges(DiGraphView const &g, diff --git a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc index 096efd49e9..66c04ec59c 100644 --- a/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc +++ b/lib/utils/src/utils/graph/digraph/algorithms/is_acyclic.cc @@ -1,5 +1,5 @@ #include "utils/graph/digraph/algorithms/is_acyclic.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/digraph/algorithms/get_successors.h" #include "utils/graph/node/algorithms.h" #include @@ -11,7 +11,7 @@ enum class ExplorationStatus { NOT_EXPLORED, BEING_EXPLORED, FULLY_EXPLORED }; bool is_acyclic(DiGraphView const &g) { std::unordered_map status = - generate_map(get_nodes(g), [](Node const &n) { + generate_unordered_map(get_nodes(g), [](Node const &n) { return ExplorationStatus::NOT_EXPLORED; }); diff --git a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc index 941c8e8e3e..903ba0c589 100644 --- a/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc +++ b/lib/utils/src/utils/graph/instances/adjacency_multidigraph.cc @@ -1,12 +1,12 @@ #include "utils/graph/instances/adjacency_multidigraph.h" #include "utils/containers/contains_key.h" #include "utils/containers/extend.h" -#include "utils/containers/generate_map.h" -#include "utils/containers/keys.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/values.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" #include "utils/hash/unordered_set.h" +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -26,8 +26,8 @@ AdjacencyMultiDiGraph::AdjacencyMultiDiGraph( Node AdjacencyMultiDiGraph::add_node() { Node new_node = this->node_source.new_node(); std::unordered_set all_nodes = - set_union(keys(this->adjacency), {new_node}); - this->adjacency[new_node] = generate_map(all_nodes, [](Node const &) { + set_union(unordered_keys(this->adjacency), {new_node}); + this->adjacency[new_node] = generate_unordered_map(all_nodes, [](Node const &) { return std::unordered_set{}; }); @@ -77,15 +77,15 @@ void AdjacencyMultiDiGraph::remove_edge(MultiDiEdge const &e) { std::unordered_set AdjacencyMultiDiGraph::query_nodes(NodeQuery const &q) const { - return apply_query(q.nodes, keys(this->adjacency)); + return apply_query(q.nodes, unordered_keys(this->adjacency)); } std::unordered_set AdjacencyMultiDiGraph::query_edges(MultiDiEdgeQuery const &q) const { std::unordered_set result; - std::unordered_set srcs = apply_query(q.srcs, keys(this->adjacency)); - std::unordered_set dsts = apply_query(q.dsts, keys(this->adjacency)); + std::unordered_set srcs = apply_query(q.srcs, unordered_keys(this->adjacency)); + std::unordered_set dsts = apply_query(q.dsts, unordered_keys(this->adjacency)); for (Node const &src : srcs) { for (Node const &dst : dsts) { extend(result, this->adjacency.at(src).at(dst)); @@ -108,8 +108,8 @@ void AdjacencyMultiDiGraph::inplace_materialize_from( std::unordered_set nodes = get_nodes(g); std::unordered_set edges = get_edges(g); - this->adjacency = generate_map(nodes, [&](Node const &) { - return generate_map( + this->adjacency = generate_unordered_map(nodes, [&](Node const &) { + return generate_unordered_map( nodes, [&](Node const &) { return std::unordered_set{}; }); }); this->edge_nodes.clear(); diff --git a/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc new file mode 100644 index 0000000000..b1d2988223 --- /dev/null +++ b/lib/utils/src/utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.cc @@ -0,0 +1,12 @@ +#include "utils/graph/kwarg_dataflow_graph/algorithms/get_kwarg_dataflow_value_uses.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using SlotName = ordered_value_type<0>; + +template std::unordered_set> + get_kwarg_dataflow_value_uses(KwargDataflowGraphView const &, + KwargDataflowOutput const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc index db181fbe73..ddc3f4f7ae 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_incoming_edges.cc @@ -7,6 +7,7 @@ #include "utils/graph/multidigraph/multidiedge_query.dtg.h" #include "utils/graph/node/algorithms.h" #include "utils/graph/query_set.h" +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { @@ -28,11 +29,11 @@ std::unordered_map> query_set::match_values_in(set_of(ns)), }; - std::unordered_map> result = map_values( + std::map> result = map_values( group_by(g.query_edges(query), [&](MultiDiEdge const &e) { return g.get_multidiedge_dst(e); }) .l_to_r(), - [](nonempty_unordered_set const &s) + [](nonempty_set const &s) -> std::unordered_set { return s.unwrap_as_unordered_set(); }); @@ -41,7 +42,7 @@ std::unordered_map> result[n]; } - return result; + return unordered_map_from_map(result); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.cc index 826c03f476..466bb903d9 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.cc @@ -1,5 +1,5 @@ #include "utils/graph/multidigraph/algorithms/get_multidiedge_to_diedge_map.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/multidigraph/algorithms/get_directed_edge.h" #include "utils/graph/multidigraph/algorithms/get_edges.h" @@ -7,7 +7,7 @@ namespace FlexFlow { std::unordered_map get_multidiedge_to_diedge_map(MultiDiGraphView const &g) { - return generate_map(get_edges(g), [&](MultiDiEdge const &e) { + return generate_unordered_map(get_edges(g), [&](MultiDiEdge const &e) { return get_directed_edge(g, e); }); } diff --git a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc index 28e181ebb9..143e59b3db 100644 --- a/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc +++ b/lib/utils/src/utils/graph/multidigraph/algorithms/get_outgoing_edges.cc @@ -5,6 +5,7 @@ #include "utils/graph/multidigraph/algorithms/get_edges.h" #include "utils/graph/node/algorithms.h" #include +#include "utils/containers/unordered_map_from_map.h" namespace FlexFlow { @@ -26,11 +27,11 @@ std::unordered_map> query_set::matchall(), }; - std::unordered_map> result = map_values( + std::map> result = map_values( group_by(g.query_edges(query), [&](MultiDiEdge const &e) { return g.get_multidiedge_src(e); }) .l_to_r(), - [](nonempty_unordered_set const &s) + [](nonempty_set const &s) -> std::unordered_set { return s.unwrap_as_unordered_set(); }); @@ -39,7 +40,7 @@ std::unordered_map> result[n]; } - return result; + return unordered_map_from_map(result); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc index 0228fdd8e9..b68aa4b1d8 100644 --- a/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc +++ b/lib/utils/src/utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.cc @@ -1,5 +1,5 @@ #include "utils/graph/open_dataflow_graph/algorithms/get_incoming_edges.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/sorted_by.h" #include "utils/containers/transform.h" #include "utils/graph/dataflow_graph/dataflow_edge_query.h" @@ -45,7 +45,7 @@ std::vector get_incoming_edges(OpenDataflowGraphView const &g, std::unordered_map> get_incoming_edges(OpenDataflowGraphView const &g, std::unordered_set const &ns) { - return generate_map(ns, + return generate_unordered_map(ns, [&](Node const &n) { return get_incoming_edges(g, n); }); } diff --git a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc index 6f1c2ace68..398abb3faf 100644 --- a/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc +++ b/lib/utils/src/utils/graph/series_parallel/binary_sp_decomposition_tree/balanced_binary_sp_tree_from_nary.cc @@ -13,6 +13,7 @@ #include "utils/overload.h" #include #include +#include "utils/containers/multiset_of.h" namespace FlexFlow { @@ -43,8 +44,8 @@ BinarySPDecompositionTree from_parallel_child(children[0]), from_parallel_child(children[1])}}; } - auto s1 = unordered_multiset_of(slice(children, 0, children.size() / 2)); - auto s2 = unordered_multiset_of( + auto s1 = multiset_of(slice(children, 0, children.size() / 2)); + auto s2 = multiset_of( slice(children, children.size() / 2, std::nullopt)); return BinarySPDecompositionTree{BinaryParallelSplit{ diff --git a/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc index 6105dda704..b5a07a5a4d 100644 --- a/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/non_normal_sp_decomposition.cc @@ -11,6 +11,8 @@ #include "utils/graph/series_parallel/series_split.dtg.h" #include "utils/overload.h" #include "utils/variant.h" +#include "utils/containers/unordered_multiset_of.h" +#include "utils/containers/multiset_of.h" namespace FlexFlow { @@ -43,7 +45,7 @@ NonNormalSPDecomposition non_normal_parallel_composition( for (NonNormalSPDecomposition const &sp_comp : sp_compositions) { if (sp_comp.has()) { composition = multiset_union( - composition, sp_comp.get().get_children()); + composition, unordered_multiset_of(sp_comp.get().get_children())); } else if (sp_comp.has()) { composition.insert(sp_comp.get()); } else { @@ -51,7 +53,7 @@ NonNormalSPDecomposition non_normal_parallel_composition( composition.insert(sp_comp.get()); } } - return NonNormalSPDecomposition(NonNormalParallelSplit{composition}); + return NonNormalSPDecomposition(NonNormalParallelSplit{multiset_of(composition)}); } static Node as_non_normal(Node const &n) { @@ -70,11 +72,11 @@ static NonNormalSeriesSplit as_non_normal(SeriesSplit const &s) { static NonNormalParallelSplit as_non_normal(ParallelSplit const &p) { return non_normal_parallel_composition( - transform(p.get_children(), + unordered_multiset_of(transform(p.get_children(), [](std::variant const &child) { return as_non_normal( widen(child)); - })) + }))) .get(); } diff --git a/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc index 3851ca38d9..5eda579f81 100644 --- a/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/normalize_sp_decomposition.cc @@ -6,6 +6,7 @@ #include "utils/graph/series_parallel/non_normal_sp_decomposition.h" #include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/variant.h" +#include "utils/containers/unordered_multiset_of.h" namespace FlexFlow { @@ -41,7 +42,7 @@ static SeriesParallelDecomposition static SeriesParallelDecomposition normalize_sp_decomposition(NonNormalParallelSplit const ¶llel) { - std::unordered_multiset normalized_children = + std::multiset normalized_children = transform(filter_empty(parallel.get_children()), [](std::variant const &child) { return normalize_sp_decomposition( @@ -54,7 +55,7 @@ static SeriesParallelDecomposition if (normalized_children.size() == 1) { return get_only(normalized_children); } - return parallel_composition(normalized_children); + return parallel_composition(unordered_multiset_of(normalized_children)); } SeriesParallelDecomposition diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc index e0075c2584..8c9f655fe5 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_decomposition.cc @@ -16,6 +16,7 @@ #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/variant.h" #include +#include "utils/containers/multiset_of.h" namespace FlexFlow { @@ -31,7 +32,7 @@ struct ToFinalAST { .value(); })}; } else { - return ParallelSplit{unordered_multiset_of(transform( + return ParallelSplit{multiset_of(transform( node.children, [](std::variant const &s) { return narrow>( @@ -137,7 +138,7 @@ SeriesParallelDecomposition parallel_composition( for (SeriesParallelDecomposition const &sp_comp : sp_compositions) { if (sp_comp.has()) { composition = multiset_union(composition, - sp_comp.get().get_children()); + unordered_multiset_of(sp_comp.get().get_children())); } else if (sp_comp.has()) { composition.insert(sp_comp.get()); } else { @@ -145,7 +146,7 @@ SeriesParallelDecomposition parallel_composition( composition.insert(sp_comp.get()); } } - return SeriesParallelDecomposition(ParallelSplit{composition}); + return SeriesParallelDecomposition(ParallelSplit{multiset_of(composition)}); } } // namespace FlexFlow diff --git a/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc b/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc index fc7cad225a..590912e93f 100644 --- a/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc +++ b/lib/utils/src/utils/graph/series_parallel/series_parallel_metrics.cc @@ -5,6 +5,7 @@ #include "utils/containers/values.h" #include "utils/containers/vector_of.h" #include "utils/fmt/unordered_multiset.h" +#include "utils/fmt/multiset.h" #include "utils/graph/digraph/algorithms/get_edges.h" #include "utils/graph/digraph/algorithms/get_longest_path_lengths_from_root.h" #include "utils/graph/digraph/digraph_view.h" @@ -15,6 +16,7 @@ #include "utils/nonnegative_int/nonnegative_int.h" #include "utils/variant.h" #include + namespace FlexFlow { static std::unordered_map diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc index 36b8e7294b..e3e24294b9 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/escribano_algo.cc @@ -147,7 +147,7 @@ static std::unordered_set return filter_out_sync_nodes(forest, node_roles); } -static std::pair, nonempty_unordered_set> +static std::pair, nonempty_set> get_up_and_down_sets( DiGraph const &g, std::unordered_set const &forest, @@ -229,7 +229,7 @@ SeriesParallelDecomposition escribano_sp_ization(DiGraph g) { std::unordered_set forest = get_forest_escribano(sp, handle, component, node_roles); - std::pair, nonempty_unordered_set> + std::pair, nonempty_set> up_down_sets = get_up_and_down_sets(sp, forest, depth_map); std::unordered_set up = up_down_sets.first.unwrap_as_unordered_set(); diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc index 7206ec5cda..b22d7d3811 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/flexible_algo.cc @@ -3,7 +3,7 @@ #include "utils/containers/argmin.h" #include "utils/containers/contains.h" #include "utils/containers/filter.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/containers/get_only.h" #include "utils/containers/intersection.h" #include "utils/containers/is_subseteq_of.h" @@ -233,7 +233,7 @@ static std::unordered_set ASSERT(!candidate_nodes.empty()); std::unordered_map critical_path_costs = - generate_map(candidate_nodes, [&](Node const &node) { + generate_unordered_map(candidate_nodes, [&](Node const &node) { std::unordered_set preds = get_predecessors(g, node); float max_parent_cost = maximum(transform(preds, [&](Node const &pred) { return sp_longest_paths.at(pred); @@ -253,7 +253,7 @@ static std::unordered_set static bool cost_map_is_valid(DiGraphView const &g, std::unordered_map const &cost_map) { - bool has_correct_nodes = get_nodes(g) == keys(cost_map); + bool has_correct_nodes = (get_nodes(g) == unordered_keys(cost_map)); bool has_nonnegative_costs = all_of(values(cost_map), [&](float const &cost) { return cost >= 0.0f; }); return has_correct_nodes && has_nonnegative_costs; diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc index 3c38b23f2b..4ebaf89756 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/naive_stratum_sync.cc @@ -1,6 +1,5 @@ #include "utils/graph/series_parallel/sp_ization/naive_stratum_sync.h" #include "utils/containers/group_by.h" -#include "utils/containers/keys.h" #include "utils/containers/maximum.h" #include "utils/containers/range.h" #include "utils/containers/transform.h" @@ -13,6 +12,7 @@ #include "utils/graph/series_parallel/series_parallel_decomposition.h" #include "utils/graph/series_parallel/sp_ization/dependencies_are_maintained.h" #include +#include "utils/containers/unordered_keys.h" namespace FlexFlow { @@ -21,7 +21,7 @@ std::vector> std::unordered_map node_to_stratum = get_longest_path_lengths_from_root(g); - std::unordered_set nodes = keys(node_to_stratum); + std::unordered_set nodes = unordered_keys(node_to_stratum); OneToMany strata_to_nodes = group_by(nodes, [&](Node const &n) { return node_to_stratum.at(n); }); diff --git a/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc b/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc index 97b8f11ec3..a6d4183a23 100644 --- a/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc +++ b/lib/utils/src/utils/graph/series_parallel/sp_ization/node_role.cc @@ -1,5 +1,5 @@ #include "utils/graph/series_parallel/sp_ization/node_role.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_predecessors.h" #include "utils/graph/digraph/algorithms/get_successors.h" @@ -10,7 +10,7 @@ namespace FlexFlow { std::unordered_map get_initial_node_role_map(DiGraphView const &g) { - return generate_map(get_nodes(g), + return generate_unordered_map(get_nodes(g), [](Node const &) { return NodeRole::PURE; }); } diff --git a/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc b/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc index a12de02656..b987153285 100644 --- a/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc +++ b/lib/utils/src/utils/many_to_one/exhaustive_relational_join.cc @@ -1,11 +1,11 @@ #include "utils/many_to_one/exhaustive_relational_join.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using T1 = value_type<0>; -using T2 = value_type<1>; -using T3 = value_type<2>; +using T1 = ordered_value_type<0>; +using T2 = ordered_value_type<1>; +using T3 = ordered_value_type<2>; template ManyToOne exhaustive_relational_join(ManyToOne const &, diff --git a/lib/utils/src/utils/many_to_one/invert_many_to_one.cc b/lib/utils/src/utils/many_to_one/invert_many_to_one.cc index 92570a1c7f..adb63f5fd8 100644 --- a/lib/utils/src/utils/many_to_one/invert_many_to_one.cc +++ b/lib/utils/src/utils/many_to_one/invert_many_to_one.cc @@ -1,10 +1,10 @@ #include "utils/many_to_one/invert_many_to_one.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template OneToMany invert_many_to_one(ManyToOne const &); diff --git a/lib/utils/src/utils/many_to_one/many_to_one.cc b/lib/utils/src/utils/many_to_one/many_to_one.cc index bbb3bfcc14..52ae52153b 100644 --- a/lib/utils/src/utils/many_to_one/many_to_one.cc +++ b/lib/utils/src/utils/many_to_one/many_to_one.cc @@ -1,5 +1,5 @@ #include "utils/many_to_one/many_to_one.h" -#include "utils/archetypes/jsonable_value_type.h" +#include "utils/archetypes/jsonable_ordered_value_type.h" #include "utils/archetypes/rapidcheckable_value_type.h" #include "utils/archetypes/value_type.h" @@ -17,12 +17,18 @@ template std::unordered_map, R> template std::ostream &operator<<(std::ostream &, ManyToOne const &); +template std::unordered_set> + unstructured_relation_from_many_to_one(ManyToOne const &); + +template ManyToOne many_to_one_from_unstructured_relation( + std::unordered_set> const &); + } // namespace FlexFlow namespace nlohmann { -using L = ::FlexFlow::jsonable_value_type<0>; -using R = ::FlexFlow::jsonable_value_type<1>; +using L = ::FlexFlow::jsonable_ordered_value_type<0>; +using R = ::FlexFlow::jsonable_ordered_value_type<1>; template struct adl_serializer<::FlexFlow::ManyToOne>; diff --git a/lib/utils/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc b/lib/utils/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc deleted file mode 100644 index dc03030f20..0000000000 --- a/lib/utils/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/many_to_one/many_to_one_from_unstructured_relation.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using L = value_type<0>; -using R = value_type<1>; - -template ManyToOne many_to_one_from_unstructured_relation( - std::unordered_set> const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc b/lib/utils/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc deleted file mode 100644 index d89df51b79..0000000000 --- a/lib/utils/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/many_to_one/unstructured_relation_from_many_to_one.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using L = value_type<0>; -using R = value_type<1>; - -template std::unordered_set> - unstructured_relation_from_many_to_one(ManyToOne const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/nonempty_set/nonempty_set.cc b/lib/utils/src/utils/nonempty_set/nonempty_set.cc new file mode 100644 index 0000000000..a7bd7f7c5a --- /dev/null +++ b/lib/utils/src/utils/nonempty_set/nonempty_set.cc @@ -0,0 +1,31 @@ +#include "utils/nonempty_set/nonempty_set.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/jsonable_ordered_value_type.h" + +using T = ::FlexFlow::ordered_value_type<0>; +using J = ::FlexFlow::jsonable_ordered_value_type<0>; + +namespace FlexFlow { + +template struct nonempty_set; + +template bool operator==(std::set const &, nonempty_set const &); + +template bool operator!=(std::set const &, nonempty_set const &); + +template std::set format_as(nonempty_set const &); +template std::ostream &operator<<(std::ostream &, nonempty_set const &); + +} // namespace FlexFlow + +namespace nlohmann { + +template struct adl_serializer<::FlexFlow::nonempty_set>; + +} // namespace nlohmann + +namespace std { + +template struct hash<::FlexFlow::nonempty_set>; + +} // namespace std diff --git a/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc b/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc index 6d237732e9..ae8fdc0d99 100644 --- a/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc +++ b/lib/utils/src/utils/one_to_many/exhaustive_relational_join.cc @@ -1,11 +1,11 @@ #include "utils/one_to_many/exhaustive_relational_join.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using T1 = value_type<0>; -using T2 = value_type<1>; -using T3 = value_type<2>; +using T1 = ordered_value_type<0>; +using T2 = ordered_value_type<1>; +using T3 = ordered_value_type<2>; template OneToMany exhaustive_relational_join(OneToMany const &, diff --git a/lib/utils/src/utils/one_to_many/invert_one_to_many.cc b/lib/utils/src/utils/one_to_many/invert_one_to_many.cc index cb911ff60a..45edf29b3c 100644 --- a/lib/utils/src/utils/one_to_many/invert_one_to_many.cc +++ b/lib/utils/src/utils/one_to_many/invert_one_to_many.cc @@ -1,10 +1,10 @@ #include "utils/one_to_many/invert_one_to_many.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template ManyToOne invert_one_to_many(OneToMany const &); diff --git a/lib/utils/src/utils/one_to_many/one_to_many.cc b/lib/utils/src/utils/one_to_many/one_to_many.cc index 158d2e10c9..ce6220c509 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many.cc @@ -1,28 +1,32 @@ #include "utils/one_to_many/one_to_many.h" #include "utils/archetypes/jsonable_value_type.h" #include "utils/archetypes/rapidcheckable_value_type.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" +#include "utils/archetypes/jsonable_ordered_value_type.h" using namespace ::FlexFlow; namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template struct OneToMany; -template std::unordered_map> +template std::map> format_as(OneToMany const &); template std::ostream &operator<<(std::ostream &, OneToMany const &); +template std::unordered_set> + unstructured_relation_from_one_to_many(OneToMany const &); + } // namespace FlexFlow namespace nlohmann { -using L = ::FlexFlow::jsonable_value_type<0>; -using R = ::FlexFlow::jsonable_value_type<1>; +using L = ::FlexFlow::jsonable_ordered_value_type<0>; +using R = ::FlexFlow::jsonable_ordered_value_type<1>; template struct adl_serializer<::FlexFlow::OneToMany>; @@ -40,8 +44,8 @@ template struct Arbitrary<::FlexFlow::OneToMany>; namespace std { -using L = ::FlexFlow::value_type<0>; -using R = ::FlexFlow::value_type<1>; +using L = ::FlexFlow::ordered_value_type<0>; +using R = ::FlexFlow::ordered_value_type<1>; template struct hash>; diff --git a/lib/utils/src/utils/one_to_many/one_to_many_filter_keys.cc b/lib/utils/src/utils/one_to_many/one_to_many_filter_keys.cc new file mode 100644 index 0000000000..a8855e1857 --- /dev/null +++ b/lib/utils/src/utils/one_to_many/one_to_many_filter_keys.cc @@ -0,0 +1,12 @@ +#include "utils/one_to_many/one_to_many_filter_keys.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; +using F = std::function; + +template OneToMany one_to_many_filter_keys(OneToMany const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many_filter_values.cc b/lib/utils/src/utils/one_to_many/one_to_many_filter_values.cc new file mode 100644 index 0000000000..fa11ccf56f --- /dev/null +++ b/lib/utils/src/utils/one_to_many/one_to_many_filter_values.cc @@ -0,0 +1,13 @@ +#include "utils/one_to_many/one_to_many_filter_values.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; + +using F = std::function; + +template OneToMany one_to_many_filter_values(OneToMany const &, F &&); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc b/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc index bd6c976488..3dad3d6a1d 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many_from_bidict.cc @@ -1,10 +1,10 @@ #include "utils/one_to_many/one_to_many_from_bidict.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template OneToMany one_to_many_from_bidict(bidict const &); diff --git a/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc b/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc index 76f0a221c5..124adb20c3 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many_from_l_to_r_mapping.cc @@ -1,10 +1,10 @@ #include "utils/one_to_many/one_to_many_from_l_to_r_mapping.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template OneToMany one_to_many_from_l_to_r_mapping( std::unordered_map> const &); diff --git a/lib/utils/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc b/lib/utils/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc deleted file mode 100644 index 4fdd52ef2e..0000000000 --- a/lib/utils/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/one_to_many/one_to_many_from_unstructured_relation.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using L = value_type<0>; -using R = value_type<1>; - -template OneToMany one_to_many_from_unstructured_relation( - std::unordered_set> const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc b/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc index 141db7f1da..47a4652ce2 100644 --- a/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc +++ b/lib/utils/src/utils/one_to_many/one_to_many_transform_values.cc @@ -1,11 +1,11 @@ #include "utils/one_to_many/one_to_many_transform_values.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R1 = value_type<1>; -using R2 = value_type<2>; +using L = ordered_value_type<0>; +using R1 = ordered_value_type<1>; +using R2 = ordered_value_type<2>; using F = std::function; template OneToMany one_to_many_transform_values(OneToMany const &, diff --git a/lib/utils/src/utils/one_to_many/require_one_to_many_is_bijection.cc b/lib/utils/src/utils/one_to_many/require_one_to_many_is_bijection.cc new file mode 100644 index 0000000000..60a78bb123 --- /dev/null +++ b/lib/utils/src/utils/one_to_many/require_one_to_many_is_bijection.cc @@ -0,0 +1,12 @@ +#include "utils/one_to_many/require_one_to_many_is_bijection.h" +#include "utils/archetypes/ordered_value_type.h" + +namespace FlexFlow { + +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; + +template + bidict require_one_to_many_is_bijection(OneToMany const &); + +} // namespace FlexFlow diff --git a/lib/utils/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc b/lib/utils/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc deleted file mode 100644 index 9a48510a3a..0000000000 --- a/lib/utils/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc +++ /dev/null @@ -1,12 +0,0 @@ -#include "utils/one_to_many/unstructured_relation_from_one_to_many.h" -#include "utils/archetypes/value_type.h" - -namespace FlexFlow { - -using L = value_type<0>; -using R = value_type<1>; - -template std::unordered_set> - unstructured_relation_from_one_to_many(OneToMany const &); - -} // namespace FlexFlow diff --git a/lib/utils/src/utils/orthotope/dim_domain_mapping.cc b/lib/utils/src/utils/orthotope/dim_domain_mapping.cc index 03762dadca..bd0f46e3dc 100644 --- a/lib/utils/src/utils/orthotope/dim_domain_mapping.cc +++ b/lib/utils/src/utils/orthotope/dim_domain_mapping.cc @@ -1,9 +1,9 @@ #include "utils/orthotope/dim_domain_mapping.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" -using ::FlexFlow::value_type; -using L = value_type<0>; -using R = value_type<1>; +using ::FlexFlow::ordered_value_type; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; namespace FlexFlow { @@ -32,9 +32,9 @@ template DimDomainMapping DimOrdering const &, DimOrdering const &); -using T1 = value_type<2>; -using T2 = value_type<3>; -using T3 = value_type<4>; +using T1 = ordered_value_type<2>; +using T2 = ordered_value_type<3>; +using T3 = ordered_value_type<4>; template DimDomainMapping compose_dim_domain_mappings(DimDomainMapping const &, diff --git a/lib/utils/src/utils/orthotope/dim_projection.cc b/lib/utils/src/utils/orthotope/dim_projection.cc index fdf0472a36..9fa250fa47 100644 --- a/lib/utils/src/utils/orthotope/dim_projection.cc +++ b/lib/utils/src/utils/orthotope/dim_projection.cc @@ -1,10 +1,11 @@ #include "utils/orthotope/dim_projection.h" #include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template DimProjection dim_projection_identity_map(DimDomain const &, @@ -27,9 +28,9 @@ template DimCoord compute_dim_projection(DimProjection const &, DimOrdering const &, DimOrdering const &); -using T1 = value_type<2>; -using T2 = value_type<3>; -using T3 = value_type<4>; +using T1 = ordered_value_type<2>; +using T2 = ordered_value_type<3>; +using T3 = ordered_value_type<4>; template DimProjection right_compose_eq_projection(DimProjection const &, diff --git a/lib/utils/src/utils/orthotope/down_projection.cc b/lib/utils/src/utils/orthotope/down_projection.cc index 73842ecc11..684521a5de 100644 --- a/lib/utils/src/utils/orthotope/down_projection.cc +++ b/lib/utils/src/utils/orthotope/down_projection.cc @@ -1,10 +1,10 @@ #include "utils/orthotope/down_projection.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template DownProjection make_empty_down_projection(); @@ -26,9 +26,9 @@ template void project_dims(DownProjection &, template UpProjection invert_down_projection(DownProjection const &); -using T1 = value_type<2>; -using T2 = value_type<3>; -using T3 = value_type<4>; +using T1 = ordered_value_type<2>; +using T2 = ordered_value_type<3>; +using T3 = ordered_value_type<4>; template DownProjection compose_down_projections(DownProjection const &, diff --git a/lib/utils/src/utils/orthotope/eq_projection.cc b/lib/utils/src/utils/orthotope/eq_projection.cc index 877b3f93ae..a6965dc36e 100644 --- a/lib/utils/src/utils/orthotope/eq_projection.cc +++ b/lib/utils/src/utils/orthotope/eq_projection.cc @@ -1,10 +1,10 @@ #include "utils/orthotope/eq_projection.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" namespace FlexFlow { -using L = value_type<0>; -using R = value_type<1>; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; template EqProjection make_empty_eq_projection(); @@ -18,9 +18,9 @@ template void project_dims(EqProjection &, L const &, R const &); template EqProjection invert_eq_projection(EqProjection const &); -using T1 = value_type<0>; -using T2 = value_type<1>; -using T3 = value_type<2>; +using T1 = ordered_value_type<0>; +using T2 = ordered_value_type<1>; +using T3 = ordered_value_type<2>; template EqProjection compose_eq_projections(EqProjection const &, diff --git a/lib/utils/src/utils/orthotope/minimal_dim_domain_mapping.cc b/lib/utils/src/utils/orthotope/minimal_dim_domain_mapping.cc index a867abfd0a..5d4c47b491 100644 --- a/lib/utils/src/utils/orthotope/minimal_dim_domain_mapping.cc +++ b/lib/utils/src/utils/orthotope/minimal_dim_domain_mapping.cc @@ -1,9 +1,9 @@ #include "utils/orthotope/minimal_dim_domain_mapping.h" -#include "utils/archetypes/value_type.h" +#include "utils/archetypes/ordered_value_type.h" -using ::FlexFlow::value_type; -using L = value_type<0>; -using R = value_type<1>; +using ::FlexFlow::ordered_value_type; +using L = ordered_value_type<0>; +using R = ordered_value_type<1>; namespace FlexFlow { @@ -40,9 +40,9 @@ template MinimalDimDomainMapping DimOrdering const &, DimOrdering const &); -using T1 = value_type<2>; -using T2 = value_type<3>; -using T3 = value_type<4>; +using T1 = ordered_value_type<2>; +using T2 = ordered_value_type<3>; +using T3 = ordered_value_type<4>; template MinimalDimDomainMapping compose_minimal_dim_domain_mappings( MinimalDimDomainMapping const &, diff --git a/lib/utils/src/utils/orthotope/up_projection.cc b/lib/utils/src/utils/orthotope/up_projection.cc index 604cec08ed..0c8909dffd 100644 --- a/lib/utils/src/utils/orthotope/up_projection.cc +++ b/lib/utils/src/utils/orthotope/up_projection.cc @@ -1,12 +1,11 @@ #include "utils/orthotope/up_projection.h" #include "utils/archetypes/ordered_value_type.h" -#include "utils/archetypes/value_type.h" namespace FlexFlow { -using T1 = value_type<0>; -using T2 = value_type<1>; -using T3 = value_type<2>; +using T1 = ordered_value_type<0>; +using T2 = ordered_value_type<1>; +using T3 = ordered_value_type<2>; template UpProjection compose_up_projections(UpProjection const &, diff --git a/lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_filter_keys.cc similarity index 64% rename from lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_filter_keys.cc index 3c3097fa9b..2a0806324b 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/filter_keys.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_filter_keys.cc @@ -1,20 +1,21 @@ -#include "utils/bidict/algorithms/filter_keys.h" +#include "utils/bidict/algorithms/bidict_filter_keys.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("filter_keys(bidict, F)") { + TEST_CASE("bidict_filter_keys(bidict, F)") { bidict dict = { {1, "one"}, {2, "two"}, }; bidict result = - filter_keys(dict, [](int k) { return k == 1; }); + bidict_filter_keys(dict, [](int k) { return k == 1; }); bidict correct = { {1, "one"}, }; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/bidict/algorithms/filter_values.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_filter_values.cc similarity index 61% rename from lib/utils/test/src/utils/bidict/algorithms/filter_values.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_filter_values.cc index 54d0bad199..5c67e11c4f 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/filter_values.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_filter_values.cc @@ -1,17 +1,17 @@ -#include "utils/bidict/algorithms/filter_values.h" +#include "utils/bidict/algorithms/bidict_filter_values.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("filter_values(bidict, F") { + TEST_CASE("bidict_filter_values(bidict, F") { bidict dict = { {1, "one"}, {2, "two"}, }; bidict result = - filter_values(dict, [](std::string const &v) { return v == "two"; }); + bidict_filter_values(dict, [](std::string const &v) { return v == "two"; }); bidict correct = { {2, "two"}, }; diff --git a/lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_keys.cc similarity index 75% rename from lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_keys.cc index 300918f978..a2774440e5 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/filtrans_keys.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_keys.cc @@ -1,17 +1,17 @@ -#include "utils/bidict/algorithms/filtrans_keys.h" +#include "utils/bidict/algorithms/bidict_filtrans_keys.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("filtrans_keys(bidict, F)") { + TEST_CASE("bidict_filtrans_keys") { bidict dict = { {1, "one"}, {2, "two"}, }; bidict result = - filtrans_keys(dict, [](int k) -> std::optional { + bidict_filtrans_keys(dict, [](int k) -> std::optional { if (k == 1) { return std::nullopt; } else { diff --git a/lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_values.cc similarity index 69% rename from lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_values.cc index 99aaef114d..8687d539bc 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/filtrans_values.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_filtrans_values.cc @@ -1,26 +1,28 @@ -#include "utils/bidict/algorithms/filtrans_values.h" +#include "utils/bidict/algorithms/bidict_filtrans_values.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("filtrans_values(bidict, F)") { + TEST_CASE("bidict_filtrans_values") { bidict dict = { {1, "one"}, {2, "two"}, }; bidict result = - filtrans_values(dict, [](std::string const &v) -> std::optional { + bidict_filtrans_values(dict, [](std::string const &v) -> std::optional { if (v == "two") { return std::nullopt; } else { return v.size() + 1; } }); + bidict correct = { {1, 4}, }; + CHECK(result == correct); } } diff --git a/lib/utils/test/src/utils/bidict/algorithms/unordered_set_of.cc b/lib/utils/test/src/utils/bidict/algorithms/bidict_unordered_set_of.cc similarity index 77% rename from lib/utils/test/src/utils/bidict/algorithms/unordered_set_of.cc rename to lib/utils/test/src/utils/bidict/algorithms/bidict_unordered_set_of.cc index b88b6df0ca..d44a2fe62b 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/unordered_set_of.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/bidict_unordered_set_of.cc @@ -1,4 +1,4 @@ -#include "utils/bidict/algorithms/unordered_set_of.h" +#include "utils/bidict/algorithms/bidict_unordered_set_of.h" #include using namespace ::FlexFlow; diff --git a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc b/lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc similarity index 72% rename from lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc rename to lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc index 0a1babd9f9..8a3371b8d8 100644 --- a/lib/utils/test/src/utils/bidict/algorithms/merge_disjoint_bidicts.cc +++ b/lib/utils/test/src/utils/bidict/algorithms/binary_merge_disjoint_bidicts.cc @@ -1,17 +1,17 @@ -#include "utils/bidict/algorithms/merge_disjoint_bidicts.h" +#include "utils/bidict/algorithms/binary_merge_disjoint_bidicts.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("merge_disjoint_bidicts") { + TEST_CASE("binary_merge_disjoint_bidicts") { SUBCASE("disjoint keys and values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "three"}, {4, "four"}}; - bidict result = merge_disjoint_bidicts(bd1, bd2); + bidict result = binary_merge_disjoint_bidicts(bd1, bd2); bidict correct = { {1, "one"}, {2, "two"}, {3, "three"}, {4, "four"}}; @@ -22,21 +22,21 @@ TEST_SUITE(FF_TEST_SUITE) { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "three"}, {3, "four"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping key, same associated value") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{2, "two"}, {3, "three"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } SUBCASE("overlapping values") { bidict bd1 = {{1, "one"}, {2, "two"}}; bidict bd2 = {{3, "two"}, {4, "four"}}; - CHECK_THROWS(merge_disjoint_bidicts(bd1, bd2)); + CHECK_THROWS(binary_merge_disjoint_bidicts(bd1, bd2)); } } } diff --git a/lib/utils/test/src/utils/bidict/bidict.cc b/lib/utils/test/src/utils/bidict/bidict.cc index 1365d04027..ead45fe86a 100644 --- a/lib/utils/test/src/utils/bidict/bidict.cc +++ b/lib/utils/test/src/utils/bidict/bidict.cc @@ -63,14 +63,14 @@ TEST_SUITE(FF_TEST_SUITE) { SUBCASE("bidict::erase_l") { dict.erase_l(1); CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_l(1), std::out_of_range); + CHECK_THROWS(dict.at_l(1)); CHECK(dict.at_r("two") == 2); } SUBCASE("bidict::erase_r") { dict.erase_r("one"); CHECK(dict.size() == 1); - CHECK_THROWS_AS(dict.at_r("one"), std::out_of_range); + CHECK_THROWS(dict.at_r("one")); CHECK(dict.at_l(2) == "two"); } @@ -113,11 +113,13 @@ TEST_SUITE(FF_TEST_SUITE) { bidict deserialized = bidict{ {2, "hello"}, {3, "goodbye"}, + {4, "yes"}, }; nlohmann::json serialized = std::vector>{ {2, "hello"}, {3, "goodbye"}, + {4, "yes"}, }; SUBCASE("to_json") { diff --git a/lib/utils/test/src/utils/containers/binary_merge_disjoint_maps.cc b/lib/utils/test/src/utils/containers/binary_merge_disjoint_maps.cc index bcc7b4149f..d4487343f2 100644 --- a/lib/utils/test/src/utils/containers/binary_merge_disjoint_maps.cc +++ b/lib/utils/test/src/utils/containers/binary_merge_disjoint_maps.cc @@ -1,27 +1,27 @@ #include "utils/containers/binary_merge_disjoint_maps.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("binary_merge_disjoint_maps") { - std::unordered_map l_map = { + std::map l_map = { {1, "one"}, {2, "two"}, }; - std::unordered_map r_map = { + std::map r_map = { {3, "three"}, }; - std::unordered_map correct = { + std::map correct = { {1, "one"}, {2, "two"}, {3, "three"}, }; SUBCASE("maps are disjoint") { - std::unordered_map result = + std::map result = binary_merge_disjoint_maps(l_map, r_map); CHECK(result == correct); diff --git a/lib/utils/test/src/utils/containers/binary_merge_disjoint_unordered_maps.cc b/lib/utils/test/src/utils/containers/binary_merge_disjoint_unordered_maps.cc new file mode 100644 index 0000000000..250d1c7f69 --- /dev/null +++ b/lib/utils/test/src/utils/containers/binary_merge_disjoint_unordered_maps.cc @@ -0,0 +1,34 @@ +#include "utils/containers/binary_merge_disjoint_unordered_maps.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("binary_merge_disjoint_unordered_maps") { + std::unordered_map l_map = { + {1, "one"}, + {2, "two"}, + }; + + std::unordered_map r_map = { + {3, "three"}, + }; + + std::unordered_map correct = { + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + SUBCASE("maps are disjoint") { + std::unordered_map result = + binary_merge_disjoint_unordered_maps(l_map, r_map); + + CHECK(result == correct); + } + + SUBCASE("maps are not disjoint") { + CHECK_THROWS(binary_merge_disjoint_unordered_maps(l_map, l_map)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/binary_merge_maps_with.cc b/lib/utils/test/src/utils/containers/binary_merge_maps_with.cc index 55b9c428bf..6a848565e1 100644 --- a/lib/utils/test/src/utils/containers/binary_merge_maps_with.cc +++ b/lib/utils/test/src/utils/containers/binary_merge_maps_with.cc @@ -1,5 +1,5 @@ #include "utils/containers/binary_merge_maps_with.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include #include @@ -11,47 +11,47 @@ TEST_SUITE(FF_TEST_SUITE) { std::string const &) -> std::string { PANIC(); }; SUBCASE("lhs and rhs do not overlap") { - std::unordered_map lhs = { + std::map lhs = { {1, "lhs_one."}, {4, "lhs_four."}, }; - std::unordered_map rhs = { + std::map rhs = { {2, "rhs_two."}, {5, "rhs_five."}, }; - std::unordered_map correct = { + std::map correct = { {1, "lhs_one."}, {2, "rhs_two."}, {4, "lhs_four."}, {5, "rhs_five."}, }; - std::unordered_map result = + std::map result = binary_merge_maps_with(lhs, rhs, fail_if_called); CHECK(result == correct); } SUBCASE("lhs and rhs overlap") { - std::unordered_map lhs = { + std::map lhs = { {1, "lhs_one."}, {4, "lhs_four."}, }; - std::unordered_map rhs = { + std::map rhs = { {2, "rhs_two."}, {4, "rhs_four."}, {5, "rhs_five."}, }; - std::unordered_map result = binary_merge_maps_with( + std::map result = binary_merge_maps_with( lhs, rhs, [](std::string const &l, std::string const &r) { return l + r; }); - std::unordered_map correct = { + std::map correct = { {1, "lhs_one."}, {2, "rhs_two."}, {4, "lhs_four.rhs_four."}, @@ -62,47 +62,47 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("lhs is empty") { - std::unordered_map lhs = {}; + std::map lhs = {}; - std::unordered_map rhs = { + std::map rhs = { {2, "rhs_two."}, {4, "rhs_four."}, {5, "rhs_five."}, }; - std::unordered_map result = + std::map result = binary_merge_maps_with(lhs, rhs, fail_if_called); - std::unordered_map correct = rhs; + std::map correct = rhs; CHECK(result == correct); } SUBCASE("rhs is empty") { - std::unordered_map lhs = { + std::map lhs = { {1, "lhs_one."}, {4, "lhs_four."}, }; - std::unordered_map rhs = {}; + std::map rhs = {}; - std::unordered_map result = + std::map result = binary_merge_maps_with(lhs, rhs, fail_if_called); - std::unordered_map correct = lhs; + std::map correct = lhs; CHECK(result == correct); } SUBCASE("both lhs and rhs are empty") { - std::unordered_map lhs = {}; + std::map lhs = {}; - std::unordered_map rhs = {}; + std::map rhs = {}; - std::unordered_map result = + std::map result = binary_merge_maps_with(lhs, rhs, fail_if_called); - std::unordered_map correct = {}; + std::map correct = {}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/containers/binary_merge_maps_with_left_dominating.cc b/lib/utils/test/src/utils/containers/binary_merge_maps_with_left_dominating.cc index 27a389d400..fe152dd832 100644 --- a/lib/utils/test/src/utils/containers/binary_merge_maps_with_left_dominating.cc +++ b/lib/utils/test/src/utils/containers/binary_merge_maps_with_left_dominating.cc @@ -1,5 +1,5 @@ #include "utils/containers/binary_merge_maps_with_left_dominating.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include #include @@ -7,23 +7,23 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("binary_merge_maps_with_left_dominating") { - std::unordered_map l_map = { + std::map l_map = { {1, "one"}, {2, "left_two"}, }; - std::unordered_map r_map = { + std::map r_map = { {2, "right_two"}, {3, "three"}, }; - std::unordered_map correct = { + std::map correct = { {1, "one"}, {2, "left_two"}, {3, "three"}, }; - std::unordered_map result = + std::map result = binary_merge_maps_with_left_dominating(l_map, r_map); CHECK(result == correct); diff --git a/lib/utils/test/src/utils/containers/binary_merge_maps_with_right_dominating.cc b/lib/utils/test/src/utils/containers/binary_merge_maps_with_right_dominating.cc index 153266989e..c107f2b7ff 100644 --- a/lib/utils/test/src/utils/containers/binary_merge_maps_with_right_dominating.cc +++ b/lib/utils/test/src/utils/containers/binary_merge_maps_with_right_dominating.cc @@ -1,5 +1,5 @@ #include "utils/containers/binary_merge_maps_with_right_dominating.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include #include @@ -7,23 +7,23 @@ using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("binary_merge_maps_with_right_dominating") { - std::unordered_map l_map = { + std::map l_map = { {1, "one"}, {2, "left_two"}, }; - std::unordered_map r_map = { + std::map r_map = { {2, "right_two"}, {3, "three"}, }; - std::unordered_map correct = { + std::map correct = { {1, "one"}, {2, "right_two"}, {3, "three"}, }; - std::unordered_map result = + std::map result = binary_merge_maps_with_right_dominating(l_map, r_map); CHECK(result == correct); diff --git a/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with.cc b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with.cc new file mode 100644 index 0000000000..4e825b99e2 --- /dev/null +++ b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with.cc @@ -0,0 +1,110 @@ +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("binary_merge_unordered_maps_with") { + auto fail_if_called = [](std::string const &, + std::string const &) -> std::string { PANIC(); }; + + SUBCASE("lhs and rhs do not overlap") { + std::unordered_map lhs = { + {1, "lhs_one."}, + {4, "lhs_four."}, + }; + + std::unordered_map rhs = { + {2, "rhs_two."}, + {5, "rhs_five."}, + }; + + std::unordered_map correct = { + {1, "lhs_one."}, + {2, "rhs_two."}, + {4, "lhs_four."}, + {5, "rhs_five."}, + }; + + std::unordered_map result = + binary_merge_unordered_maps_with(lhs, rhs, fail_if_called); + + CHECK(result == correct); + } + + SUBCASE("lhs and rhs overlap") { + std::unordered_map lhs = { + {1, "lhs_one."}, + {4, "lhs_four."}, + }; + + std::unordered_map rhs = { + {2, "rhs_two."}, + {4, "rhs_four."}, + {5, "rhs_five."}, + }; + + std::unordered_map result = binary_merge_unordered_maps_with( + lhs, rhs, [](std::string const &l, std::string const &r) { + return l + r; + }); + + std::unordered_map correct = { + {1, "lhs_one."}, + {2, "rhs_two."}, + {4, "lhs_four.rhs_four."}, + {5, "rhs_five."}, + }; + + CHECK(result == correct); + } + + SUBCASE("lhs is empty") { + std::unordered_map lhs = {}; + + std::unordered_map rhs = { + {2, "rhs_two."}, + {4, "rhs_four."}, + {5, "rhs_five."}, + }; + + std::unordered_map result = + binary_merge_unordered_maps_with(lhs, rhs, fail_if_called); + + std::unordered_map correct = rhs; + + CHECK(result == correct); + } + + SUBCASE("rhs is empty") { + std::unordered_map lhs = { + {1, "lhs_one."}, + {4, "lhs_four."}, + }; + + std::unordered_map rhs = {}; + + std::unordered_map result = + binary_merge_unordered_maps_with(lhs, rhs, fail_if_called); + + std::unordered_map correct = lhs; + + CHECK(result == correct); + } + + SUBCASE("both lhs and rhs are empty") { + std::unordered_map lhs = {}; + + std::unordered_map rhs = {}; + + std::unordered_map result = + binary_merge_unordered_maps_with(lhs, rhs, fail_if_called); + + std::unordered_map correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc new file mode 100644 index 0000000000..d857cf2d91 --- /dev/null +++ b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_left_dominating.cc @@ -0,0 +1,31 @@ +#include "utils/containers/binary_merge_unordered_maps_with_left_dominating.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("binary_merge_unordered_maps_with_left_dominating") { + std::unordered_map l_map = { + {1, "one"}, + {2, "left_two"}, + }; + + std::unordered_map r_map = { + {2, "right_two"}, + {3, "three"}, + }; + + std::unordered_map correct = { + {1, "one"}, + {2, "left_two"}, + {3, "three"}, + }; + + std::unordered_map result = + binary_merge_unordered_maps_with_left_dominating(l_map, r_map); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc new file mode 100644 index 0000000000..71f50c4dac --- /dev/null +++ b/lib/utils/test/src/utils/containers/binary_merge_unordered_maps_with_right_dominating.cc @@ -0,0 +1,31 @@ +#include "utils/containers/binary_merge_unordered_maps_with_right_dominating.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("binary_merge_unordered_maps_with_right_dominating") { + std::unordered_map l_map = { + {1, "one"}, + {2, "left_two"}, + }; + + std::unordered_map r_map = { + {2, "right_two"}, + {3, "three"}, + }; + + std::unordered_map correct = { + {1, "one"}, + {2, "right_two"}, + {3, "three"}, + }; + + std::unordered_map result = + binary_merge_unordered_maps_with_right_dominating(l_map, r_map); + + CHECK(result == correct); + } +} diff --git a/lib/utils/test/src/utils/containers/enumerate.cc b/lib/utils/test/src/utils/containers/enumerate.cc index 2fdb2e481e..22bae1f613 100644 --- a/lib/utils/test/src/utils/containers/enumerate.cc +++ b/lib/utils/test/src/utils/containers/enumerate.cc @@ -4,7 +4,7 @@ #include "test/utils/doctest/fmt/unordered_multiset.h" #include "test/utils/doctest/fmt/unordered_set.h" #include "test/utils/doctest/fmt/vector.h" -#include "utils/containers/keys.h" +#include "utils/containers/unordered_keys.h" #include "utils/containers/unordered_multiset_of.h" #include "utils/containers/values.h" #include "utils/containers/vector_of.h" @@ -50,7 +50,7 @@ TEST_SUITE(FF_TEST_SUITE) { std::unordered_multiset correct_values = {"A", "B", "C", "D"}; std::map result = enumerate(input); - CHECK(keys(result) == correct_keys); + CHECK(unordered_keys(result) == correct_keys); CHECK(unordered_multiset_of(values(result)) == correct_values); } } diff --git a/lib/utils/test/src/utils/containers/keys.cc b/lib/utils/test/src/utils/containers/keys.cc index 5bdaef6d08..d2ac3dbfba 100644 --- a/lib/utils/test/src/utils/containers/keys.cc +++ b/lib/utils/test/src/utils/containers/keys.cc @@ -1,5 +1,5 @@ #include "utils/containers/keys.h" -#include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/set.h" #include #include #include @@ -9,10 +9,10 @@ using namespace FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("keys") { - std::unordered_map m = { + std::map m = { {1, "one"}, {2, "two"}, {3, "three"}}; - std::unordered_set result = keys(m); - std::unordered_set expected = {1, 2, 3}; + std::set result = keys(m); + std::set expected = {1, 2, 3}; CHECK(result == expected); } } diff --git a/lib/utils/test/src/utils/containers/map_from_pairs.cc b/lib/utils/test/src/utils/containers/map_from_pairs.cc index fc387f1d1a..48e8b9fe05 100644 --- a/lib/utils/test/src/utils/containers/map_from_pairs.cc +++ b/lib/utils/test/src/utils/containers/map_from_pairs.cc @@ -1,24 +1,23 @@ #include "utils/containers/map_from_pairs.h" -#include "test/utils/doctest/fmt/unordered_map.h" -#include "utils/hash/pair.h" +#include "test/utils/doctest/fmt/map.h" #include #include +#include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("map_from_pairs") { - std::unordered_set> input = - std::unordered_set>{ + std::set> input = + std::set>{ {1, "one"}, {2, "two"}, }; - std::unordered_map result = map_from_pairs(input); - - std::unordered_map correct = - std::unordered_map{ + std::map result = map_from_pairs(input); + std::map correct = + std::map{ {1, "one"}, {2, "two"}, }; diff --git a/lib/utils/test/src/utils/containers/merge_disjoint_maps.cc b/lib/utils/test/src/utils/containers/merge_disjoint_maps.cc index 24e8d548ae..bf4b2202d7 100644 --- a/lib/utils/test/src/utils/containers/merge_disjoint_maps.cc +++ b/lib/utils/test/src/utils/containers/merge_disjoint_maps.cc @@ -1,37 +1,37 @@ #include "utils/containers/merge_disjoint_maps.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include using namespace ::FlexFlow; TEST_SUITE(FF_TEST_SUITE) { TEST_CASE("merge_disjoint_maps") { - std::unordered_map m1 = { + std::map m1 = { {4, "four"}, {2, "two"}, }; - std::unordered_map m2 = { + std::map m2 = { {3, "four"}, }; - std::unordered_map m3 = { + std::map m3 = { {1, "one"}, }; - std::unordered_map m4 = {}; + std::map m4 = {}; SUBCASE("maps are disjoint") { - std::vector> input = { + std::vector> input = { m1, m2, m3, m4, }; - std::unordered_map result = merge_disjoint_maps(input); + std::map result = merge_disjoint_maps(input); - std::unordered_map correct = { + std::map correct = { {4, "four"}, {2, "two"}, {3, "four"}, @@ -42,12 +42,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("maps are not disjoint") { - std::unordered_map m5 = { + std::map m5 = { {4, "five"}, {6, "six"}, }; - std::vector> input = { + std::vector> input = { m1, m2, m3, @@ -59,12 +59,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("maps are not disjoint but have identical values") { - std::unordered_map m5 = { + std::map m5 = { {4, "four"}, {6, "six"}, }; - std::vector> input = { + std::vector> input = { m1, m2, m3, diff --git a/lib/utils/test/src/utils/containers/merge_disjoint_unordered_maps.cc b/lib/utils/test/src/utils/containers/merge_disjoint_unordered_maps.cc new file mode 100644 index 0000000000..ceecea96c1 --- /dev/null +++ b/lib/utils/test/src/utils/containers/merge_disjoint_unordered_maps.cc @@ -0,0 +1,78 @@ +#include "utils/containers/merge_disjoint_unordered_maps.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("merge_disjoint_unordered_maps") { + std::unordered_map m1 = { + {4, "four"}, + {2, "two"}, + }; + + std::unordered_map m2 = { + {3, "four"}, + }; + + std::unordered_map m3 = { + {1, "one"}, + }; + + std::unordered_map m4 = {}; + + SUBCASE("maps are disjoint") { + std::vector> input = { + m1, + m2, + m3, + m4, + }; + + std::unordered_map result = merge_disjoint_unordered_maps(input); + + std::unordered_map correct = { + {4, "four"}, + {2, "two"}, + {3, "four"}, + {1, "one"}, + }; + + CHECK(result == correct); + } + + SUBCASE("maps are not disjoint") { + std::unordered_map m5 = { + {4, "five"}, + {6, "six"}, + }; + + std::vector> input = { + m1, + m2, + m3, + m4, + m5, + }; + + CHECK_THROWS(merge_disjoint_unordered_maps(input)); + } + + SUBCASE("maps are not disjoint but have identical values") { + std::unordered_map m5 = { + {4, "four"}, + {6, "six"}, + }; + + std::vector> input = { + m1, + m2, + m3, + m4, + m5, + }; + + CHECK_THROWS(merge_disjoint_unordered_maps(input)); + } + } +} diff --git a/lib/utils/test/src/utils/containers/merge_maps_with.cc b/lib/utils/test/src/utils/containers/merge_maps_with.cc index ec5b31abf3..fd73a9345e 100644 --- a/lib/utils/test/src/utils/containers/merge_maps_with.cc +++ b/lib/utils/test/src/utils/containers/merge_maps_with.cc @@ -1,5 +1,5 @@ #include "utils/containers/merge_maps_with.h" -#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/doctest/fmt/map.h" #include "test/utils/rapidcheck.h" #include "utils/containers/binary_merge_maps_with.h" #include @@ -14,37 +14,37 @@ TEST_SUITE(FF_TEST_SUITE) { RC_SUBCASE( "with two inputs, matches binary_merge_maps_with", - [&](std::unordered_map const &lhs, - std::unordered_map const &rhs) { - std::unordered_map from_merge_maps_with = + [&](std::map const &lhs, + std::map const &rhs) { + std::map from_merge_maps_with = merge_maps_with(std::vector{lhs, rhs}, string_concat); - std::unordered_map from_binary_merge_maps_with = + std::map from_binary_merge_maps_with = binary_merge_maps_with(lhs, rhs, string_concat); CHECK(from_merge_maps_with == from_binary_merge_maps_with); }); SUBCASE("maps overlap") { - std::unordered_map map1 = { + std::map map1 = { {1, "map1_one."}, {4, "map1_four."}, }; - std::unordered_map map2 = { + std::map map2 = { {2, "map2_two."}, {4, "map2_four."}, {5, "map2_five."}, }; - std::unordered_map map3 = { + std::map map3 = { {1, "map3_one."}, }; - std::unordered_map result = + std::map result = merge_maps_with(std::vector{map1, map2, map3}, string_concat); - std::unordered_map correct = { + std::map correct = { {1, "map1_one.map3_one."}, {2, "map2_two."}, {4, "map1_four.map2_four."}, @@ -58,24 +58,24 @@ TEST_SUITE(FF_TEST_SUITE) { std::string const &) -> std::string { PANIC(); }; SUBCASE("maps do not overlap") { - std::unordered_map map1 = { + std::map map1 = { {8, "map1_eight."}, {4, "map1_four."}, }; - std::unordered_map map2 = { + std::map map2 = { {2, "map2_two."}, {5, "map2_five."}, }; - std::unordered_map map3 = { + std::map map3 = { {1, "map3_one."}, }; - std::unordered_map result = + std::map result = merge_maps_with(std::vector{map1, map2, map3}, fail_if_called); - std::unordered_map correct = { + std::map correct = { {1, "map3_one."}, {2, "map2_two."}, {4, "map1_four."}, @@ -87,12 +87,12 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("no maps are provided") { - std::vector> maps = {}; + std::vector> maps = {}; - std::unordered_map result = + std::map result = merge_maps_with(maps, fail_if_called); - std::unordered_map correct = {}; + std::map correct = {}; CHECK(result == correct); } diff --git a/lib/utils/test/src/utils/containers/merge_unordered_maps_with.cc b/lib/utils/test/src/utils/containers/merge_unordered_maps_with.cc new file mode 100644 index 0000000000..66827ca453 --- /dev/null +++ b/lib/utils/test/src/utils/containers/merge_unordered_maps_with.cc @@ -0,0 +1,100 @@ +#include "utils/containers/merge_unordered_maps_with.h" +#include "test/utils/doctest/fmt/unordered_map.h" +#include "test/utils/rapidcheck.h" +#include "utils/containers/binary_merge_unordered_maps_with.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("merge_unordered_maps_with") { + auto string_concat = [](std::string const &l, std::string const &r) { + return l + r; + }; + + RC_SUBCASE( + "with two inputs, matches binary_merge_unordered_maps_with", + [&](std::unordered_map const &lhs, + std::unordered_map const &rhs) { + std::unordered_map from_merge_unordered_maps_with = + merge_unordered_maps_with(std::vector{lhs, rhs}, string_concat); + + std::unordered_map from_binary_merge_unordered_maps_with = + binary_merge_unordered_maps_with(lhs, rhs, string_concat); + + CHECK(from_merge_unordered_maps_with == from_binary_merge_unordered_maps_with); + }); + + SUBCASE("maps overlap") { + std::unordered_map map1 = { + {1, "map1_one."}, + {4, "map1_four."}, + }; + + std::unordered_map map2 = { + {2, "map2_two."}, + {4, "map2_four."}, + {5, "map2_five."}, + }; + + std::unordered_map map3 = { + {1, "map3_one."}, + }; + + std::unordered_map result = + merge_unordered_maps_with(std::vector{map1, map2, map3}, string_concat); + + std::unordered_map correct = { + {1, "map1_one.map3_one."}, + {2, "map2_two."}, + {4, "map1_four.map2_four."}, + {5, "map2_five."}, + }; + + CHECK(result == correct); + } + + auto fail_if_called = [](std::string const &, + std::string const &) -> std::string { PANIC(); }; + + SUBCASE("maps do not overlap") { + std::unordered_map map1 = { + {8, "map1_eight."}, + {4, "map1_four."}, + }; + + std::unordered_map map2 = { + {2, "map2_two."}, + {5, "map2_five."}, + }; + + std::unordered_map map3 = { + {1, "map3_one."}, + }; + + std::unordered_map result = + merge_unordered_maps_with(std::vector{map1, map2, map3}, fail_if_called); + + std::unordered_map correct = { + {1, "map3_one."}, + {2, "map2_two."}, + {4, "map1_four."}, + {5, "map2_five."}, + {8, "map1_eight."}, + }; + + CHECK(result == correct); + } + + SUBCASE("no maps are provided") { + std::vector> maps = {}; + + std::unordered_map result = + merge_unordered_maps_with(maps, fail_if_called); + + std::unordered_map correct = {}; + + CHECK(result == correct); + } + } +} diff --git a/lib/utils/test/src/utils/graph/digraph/algorithms/get_imm_post_dominators_map.cc b/lib/utils/test/src/utils/graph/digraph/algorithms/get_imm_post_dominators_map.cc index 4435ccc26c..e92a37169b 100644 --- a/lib/utils/test/src/utils/graph/digraph/algorithms/get_imm_post_dominators_map.cc +++ b/lib/utils/test/src/utils/graph/digraph/algorithms/get_imm_post_dominators_map.cc @@ -1,5 +1,4 @@ #include "utils/graph/digraph/algorithms/get_imm_post_dominators_map.h" -#include "utils/containers/generate_map.h" #include "utils/graph/algorithms.h" #include "utils/graph/instances/adjacency_digraph.h" #include diff --git a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc index 8e276daca7..85c4d66f6d 100644 --- a/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc +++ b/lib/utils/test/src/utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.cc @@ -1,6 +1,6 @@ #include "utils/graph/series_parallel/sp_ization/work_duplicating_sp_ization.h" #include "test/utils/rapidcheck.h" -#include "utils/containers/generate_map.h" +#include "utils/containers/generate_unordered_map.h" #include "utils/graph/algorithms.h" #include "utils/graph/digraph/algorithms/get_initial_nodes.h" #include "utils/graph/digraph/algorithms/get_terminal_nodes.h" @@ -46,7 +46,7 @@ static std::pair> } std::unordered_map cost_map = - generate_map(get_nodes(g), [](Node const &) { + generate_unordered_map(get_nodes(g), [](Node const &) { return static_cast(*rc::gen::inRange(1, 101)); }); diff --git a/lib/utils/test/src/utils/many_to_one/many_to_one.cc b/lib/utils/test/src/utils/many_to_one/many_to_one.cc index 13f88fab7c..ce219676e1 100644 --- a/lib/utils/test/src/utils/many_to_one/many_to_one.cc +++ b/lib/utils/test/src/utils/many_to_one/many_to_one.cc @@ -96,6 +96,37 @@ TEST_SUITE(FF_TEST_SUITE) { } } + TEST_CASE("adl_serializer>") { + ManyToOne deserialized = ManyToOne{ + {{2, 20}, {"two"}}, + {{3}, "three"}, + {{4, 40, 400}, "four"}, + }; + + nlohmann::json serialized = std::set>{ + {2, "two"}, + {3, "three"}, + {4, "four"}, + {20, "two"}, + {40, "four"}, + {400, "four"}, + }; + + SUBCASE("to_json") { + nlohmann::json result = deserialized; + nlohmann::json correct = serialized; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + ManyToOne result = serialized; + ManyToOne correct = deserialized; + + CHECK(result == correct); + } + } + TEST_CASE("fmt::to_string(ManyToOne)") { ManyToOne input = ManyToOne{ {{1, 10, 100}, "one"}, @@ -107,4 +138,52 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(multiset_of(result) == multiset_of(correct)); } + + TEST_CASE("many_to_one_from_unstructured_relation") { + SUBCASE("relation is many-to-one") { + std::unordered_set> input = { + {1, "odd"}, + {2, "even"}, + {3, "odd"}, + }; + + ManyToOne result = + many_to_one_from_unstructured_relation(input); + ManyToOne correct = { + {{1, 3}, "odd"}, + {{2}, "even"}, + }; + + CHECK(result == correct); + } + + SUBCASE("relation is one-to-one") { + std::unordered_set> input = { + {1, "one"}, + {2, "two"}, + {3, "three"}, + }; + + ManyToOne result = + many_to_one_from_unstructured_relation(input); + ManyToOne correct = { + {{1}, "one"}, + {{2}, "two"}, + {{3}, "three"}, + }; + + CHECK(result == correct); + } + + SUBCASE("relation is not many-to-one") { + std::unordered_set> input = { + {1, "one"}, + {1, "ODD"}, + {2, "two"}, + {3, "ODD"}, + }; + + CHECK_THROWS(many_to_one_from_unstructured_relation(input)); + } + } } diff --git a/lib/utils/test/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc b/lib/utils/test/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc deleted file mode 100644 index c9d0e866d4..0000000000 --- a/lib/utils/test/src/utils/many_to_one/many_to_one_from_unstructured_relation.cc +++ /dev/null @@ -1,54 +0,0 @@ -#include "utils/many_to_one/many_to_one_from_unstructured_relation.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("many_to_one_from_unstructured_relation") { - SUBCASE("relation is many-to-one") { - std::unordered_set> input = { - {1, "odd"}, - {2, "even"}, - {3, "odd"}, - }; - - ManyToOne result = - many_to_one_from_unstructured_relation(input); - ManyToOne correct = { - {{1, 3}, "odd"}, - {{2}, "even"}, - }; - - CHECK(result == correct); - } - - SUBCASE("relation is one-to-one") { - std::unordered_set> input = { - {1, "one"}, - {2, "two"}, - {3, "three"}, - }; - - ManyToOne result = - many_to_one_from_unstructured_relation(input); - ManyToOne correct = { - {{1}, "one"}, - {{2}, "two"}, - {{3}, "three"}, - }; - - CHECK(result == correct); - } - - SUBCASE("relation is not many-to-one") { - std::unordered_set> input = { - {1, "one"}, - {1, "ODD"}, - {2, "two"}, - {3, "ODD"}, - }; - - CHECK_THROWS(many_to_one_from_unstructured_relation(input)); - } - } -} diff --git a/lib/utils/test/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc b/lib/utils/test/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc deleted file mode 100644 index 15a8f5b390..0000000000 --- a/lib/utils/test/src/utils/many_to_one/unstructured_relation_from_many_to_one.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "utils/many_to_one/unstructured_relation_from_many_to_one.h" -#include "test/utils/doctest/fmt/pair.h" -#include "test/utils/doctest/fmt/unordered_set.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("unstructured_relation_from_many_to_one") { - ManyToOne input = { - {{1, 3}, "odd"}, - {{2}, "even"}, - }; - - std::unordered_set> result = - unstructured_relation_from_many_to_one(input); - std::unordered_set> correct = { - {1, "odd"}, - {2, "even"}, - {3, "odd"}, - }; - - CHECK(result == correct); - } -} diff --git a/lib/utils/test/src/utils/one_to_many/one_to_many.cc b/lib/utils/test/src/utils/one_to_many/one_to_many.cc index d2ea7d6a0b..de149ce609 100644 --- a/lib/utils/test/src/utils/one_to_many/one_to_many.cc +++ b/lib/utils/test/src/utils/one_to_many/one_to_many.cc @@ -1,8 +1,11 @@ #include "utils/one_to_many/one_to_many.h" #include "test/utils/doctest/fmt/multiset.h" #include "test/utils/doctest/fmt/unordered_set.h" +#include "test/utils/doctest/fmt/set.h" #include "utils/containers/multiset_of.h" #include "utils/one_to_many/one_to_many_from_l_to_r_mapping.h" +#include "test/utils/doctest/fmt/pair.h" +#include "test/utils/doctest/fmt/unordered_set.h" #include using namespace ::FlexFlow; @@ -31,9 +34,9 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("at_l") { - nonempty_unordered_set result = m.at_l(1); + nonempty_set result = m.at_l(1); - nonempty_unordered_set correct = {"one", "One", "ONE"}; + nonempty_set correct = {"one", "One", "ONE"}; CHECK(result == correct); } @@ -47,17 +50,17 @@ TEST_SUITE(FF_TEST_SUITE) { } SUBCASE("left_values") { - std::unordered_set result = m.left_values(); + std::set result = m.left_values(); - std::unordered_set correct = {1, 2}; + std::set correct = {1, 2}; CHECK(result == correct); } SUBCASE("right_values") { - std::unordered_set result = m.right_values(); + std::set result = m.right_values(); - std::unordered_set correct = {"one", "One", "ONE", "two"}; + std::set correct = {"one", "One", "ONE", "two"}; CHECK(result == correct); } @@ -90,6 +93,35 @@ TEST_SUITE(FF_TEST_SUITE) { } } + TEST_CASE("adl_serializer>") { + OneToMany deserialized = OneToMany{ + {2, {"two", "TWO"}}, + {3, {"three"}}, + {4, {"four"}}, + }; + + nlohmann::json serialized = std::set>{ + {2, "two"}, + {2, "TWO"}, + {3, "three"}, + {4, "four"}, + }; + + SUBCASE("to_json") { + nlohmann::json result = deserialized; + nlohmann::json correct = serialized; + + CHECK(result == correct); + } + + SUBCASE("from_json") { + OneToMany result = serialized; + OneToMany correct = deserialized; + + CHECK(result == correct); + } + } + TEST_CASE("fmt::to_string(OneToMany)") { OneToMany input = one_to_many_from_l_to_r_mapping( @@ -100,4 +132,67 @@ TEST_SUITE(FF_TEST_SUITE) { CHECK(multiset_of(result) == multiset_of(correct)); } + + TEST_CASE("unstructured_relation_from_one_to_many") { + OneToMany input = { + {1, {"one", "ONE"}}, + {2, {"two"}}, + }; + + std::unordered_set> result = + unstructured_relation_from_one_to_many(input); + std::unordered_set> correct = { + {1, "one"}, + {1, "ONE"}, + {2, "two"}, + }; + + CHECK(result == correct); + } + + TEST_CASE("one_to_many_from_unstructured_relation") { + SUBCASE("relation is one-to-many") { + std::unordered_set> input = { + {1, "one"}, + {1, "ONE"}, + {2, "two"}, + }; + + OneToMany result = + one_to_many_from_unstructured_relation(input); + OneToMany correct = { + {1, {"one", "ONE"}}, + {2, {"two"}}, + }; + + CHECK(result == correct); + } + + SUBCASE("relation is one-to-one") { + std::unordered_set> input = { + {1, "one"}, + {2, "two"}, + }; + + OneToMany result = + one_to_many_from_unstructured_relation(input); + OneToMany correct = { + {1, {"one"}}, + {2, {"two"}}, + }; + + CHECK(result == correct); + } + + SUBCASE("relation is not one-to-many") { + std::unordered_set> input = { + {1, "one"}, + {1, "ONE"}, + {2, "two"}, + {3, "ONE"}, + }; + + CHECK_THROWS(one_to_many_from_unstructured_relation(input)); + } + } } diff --git a/lib/utils/test/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc b/lib/utils/test/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc deleted file mode 100644 index 023c556c46..0000000000 --- a/lib/utils/test/src/utils/one_to_many/one_to_many_from_unstructured_relation.cc +++ /dev/null @@ -1,53 +0,0 @@ -#include "utils/one_to_many/one_to_many_from_unstructured_relation.h" -#include -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("one_to_many_from_unstructured_relation") { - SUBCASE("relation is one-to-many") { - std::unordered_set> input = { - {1, "one"}, - {1, "ONE"}, - {2, "two"}, - }; - - OneToMany result = - one_to_many_from_unstructured_relation(input); - OneToMany correct = { - {1, {"one", "ONE"}}, - {2, {"two"}}, - }; - - CHECK(result == correct); - } - - SUBCASE("relation is one-to-one") { - std::unordered_set> input = { - {1, "one"}, - {2, "two"}, - }; - - OneToMany result = - one_to_many_from_unstructured_relation(input); - OneToMany correct = { - {1, {"one"}}, - {2, {"two"}}, - }; - - CHECK(result == correct); - } - - SUBCASE("relation is not one-to-many") { - std::unordered_set> input = { - {1, "one"}, - {1, "ONE"}, - {2, "two"}, - {3, "ONE"}, - }; - - CHECK_THROWS(one_to_many_from_unstructured_relation(input)); - } - } -} diff --git a/lib/utils/test/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc b/lib/utils/test/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc deleted file mode 100644 index 06a5d5ff2e..0000000000 --- a/lib/utils/test/src/utils/one_to_many/unstructured_relation_from_one_to_many.cc +++ /dev/null @@ -1,25 +0,0 @@ -#include "utils/one_to_many/unstructured_relation_from_one_to_many.h" -#include "test/utils/doctest/fmt/pair.h" -#include "test/utils/doctest/fmt/unordered_set.h" -#include - -using namespace ::FlexFlow; - -TEST_SUITE(FF_TEST_SUITE) { - TEST_CASE("unstructured_relation_from_one_to_many") { - OneToMany input = { - {1, {"one", "ONE"}}, - {2, {"two"}}, - }; - - std::unordered_set> result = - unstructured_relation_from_one_to_many(input); - std::unordered_set> correct = { - {1, "one"}, - {1, "ONE"}, - {2, "two"}, - }; - - CHECK(result == correct); - } -}